mirror of https://github.com/apache/lucene.git
First-class random access API for KnnVectorValues (#13779)
This commit is contained in:
parent
7b4b0238d7
commit
6053e1e313
|
@ -18,10 +18,10 @@
|
|||
package org.apache.lucene.analysis.synonym.word2vec;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.BytesRefHash;
|
||||
import org.apache.lucene.util.TermAndVector;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
|
||||
|
@ -29,7 +29,7 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class Word2VecModel implements RandomAccessVectorValues.Floats {
|
||||
public class Word2VecModel extends FloatVectorValues {
|
||||
|
||||
private final int dictionarySize;
|
||||
private final int vectorDimension;
|
||||
|
|
|
@ -22,10 +22,10 @@ import java.util.Locale;
|
|||
import java.util.Objects;
|
||||
import java.util.SplittableRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
|
||||
|
@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
private final Lucene90NeighborArray scratch;
|
||||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues.Floats vectorValues;
|
||||
private final FloatVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final Lucene90BoundsChecker bound;
|
||||
final Lucene90OnHeapHnswGraph hnsw;
|
||||
|
@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private final RandomAccessVectorValues.Floats buildVectors;
|
||||
private final FloatVectorValues buildVectors;
|
||||
|
||||
/**
|
||||
* Reads all the vectors from vector values, builds a graph connecting them by their dense
|
||||
|
@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene90HnswGraphBuilder(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
FloatVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
|
@ -97,14 +97,14 @@ public final class Lucene90HnswGraphBuilder {
|
|||
}
|
||||
|
||||
/**
|
||||
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
|
||||
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||
* Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies
|
||||
* enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||
* returned values.
|
||||
*
|
||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
|
||||
public Lucene90OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
@ -230,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
float[] candidate,
|
||||
float score,
|
||||
Lucene90NeighborArray neighbors,
|
||||
RandomAccessVectorValues.Floats vectorValues)
|
||||
FloatVectorValues vectorValues)
|
||||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.lucene.backward_codecs.lucene90;
|
|||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.SplittableRandom;
|
||||
|
@ -34,7 +33,6 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
|
@ -44,7 +42,6 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Reads vectors from the index segments along with index data structures supporting KNN search.
|
||||
|
@ -355,8 +352,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
static class OffHeapFloatVectorValues extends FloatVectorValues
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
static class OffHeapFloatVectorValues extends FloatVectorValues {
|
||||
|
||||
final int dimension;
|
||||
final int[] ordToDoc;
|
||||
|
@ -367,9 +363,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
final float[] value;
|
||||
final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
int ord = -1;
|
||||
int doc = -1;
|
||||
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int[] ordToDoc,
|
||||
|
@ -394,42 +387,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
return ordToDoc.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
if (++ord >= size()) {
|
||||
doc = NO_MORE_DOCS;
|
||||
} else {
|
||||
doc = ordToDoc[ord];
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
assert docID() < target;
|
||||
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
|
||||
if (ord < 0) {
|
||||
ord = -(ord + 1);
|
||||
}
|
||||
assert ord <= ordToDoc.length;
|
||||
if (ord == ordToDoc.length) {
|
||||
doc = NO_MORE_DOCS;
|
||||
} else {
|
||||
doc = ordToDoc[ord];
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() {
|
||||
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
|
||||
|
@ -446,21 +403,32 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return ordToDoc[ord];
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createSparseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
if (size() == 0) {
|
||||
return null;
|
||||
}
|
||||
OffHeapFloatVectorValues values = this.copy();
|
||||
DocIndexIterator iterator = values.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.similarityFunction.compare(values.vectorValue(), target);
|
||||
return values.similarityFunction.compare(values.vectorValue(iterator.index()), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
public DocIndexIterator iterator() {
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -23,12 +23,12 @@ import java.io.IOException;
|
|||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.SparseFixedBitSet;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
|
||||
|
@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
|||
float[] query,
|
||||
int topK,
|
||||
int numSeed,
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
FloatVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
HnswGraph graphValues,
|
||||
Bits acceptOrds,
|
||||
|
|
|
@ -46,7 +46,6 @@ import org.apache.lucene.util.IOUtils;
|
|||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
/**
|
||||
|
@ -398,8 +397,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
static class OffHeapFloatVectorValues extends FloatVectorValues
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
static class OffHeapFloatVectorValues extends FloatVectorValues {
|
||||
|
||||
private final int dimension;
|
||||
private final int size;
|
||||
|
@ -410,9 +408,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
private final float[] value;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
private int ord = -1;
|
||||
private int doc = -1;
|
||||
|
||||
OffHeapFloatVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -439,49 +434,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
dataIn.seek((long) ord * byteSize);
|
||||
dataIn.readFloats(value, 0, value.length);
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
if (++ord >= size) {
|
||||
doc = NO_MORE_DOCS;
|
||||
} else {
|
||||
doc = ordToDocOperator.applyAsInt(ord);
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
assert docID() < target;
|
||||
|
||||
if (ordToDoc == null) {
|
||||
ord = target;
|
||||
} else {
|
||||
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
|
||||
if (ord < 0) {
|
||||
ord = -(ord + 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (ord < size) {
|
||||
doc = ordToDocOperator.applyAsInt(ord);
|
||||
} else {
|
||||
doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() {
|
||||
return new OffHeapFloatVectorValues(
|
||||
|
@ -495,21 +447,32 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return ordToDocOperator.applyAsInt(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createSparseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
if (size == 0) {
|
||||
return null;
|
||||
}
|
||||
OffHeapFloatVectorValues values = this.copy();
|
||||
DocIndexIterator iterator = values.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.similarityFunction.compare(values.vectorValue(), target);
|
||||
return values.similarityFunction.compare(values.vectorValue(iterator.index()), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -26,12 +26,10 @@ import org.apache.lucene.search.VectorScorer;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
abstract class OffHeapFloatVectorValues extends FloatVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -95,8 +93,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -105,35 +101,16 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
super(dimension, size, vectorSimilarityFunction, slice);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
|
@ -142,15 +119,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
DenseOffHeapVectorValues values = this.copy();
|
||||
DocIndexIterator iterator = values.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
return values.vectorSimilarityFunction.compare(
|
||||
values.vectorValue(iterator.index()), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -186,33 +165,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
fieldEntry.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, dataIn, vectorSimilarityFunction, slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return IndexedDISI.asDocIndexIterator(disi);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return (int) ordToDoc.get(ord);
|
||||
|
@ -239,15 +202,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
SparseOffHeapVectorValues values = this.copy();
|
||||
DocIndexIterator iterator = values.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
return values.vectorSimilarityFunction.compare(
|
||||
values.vectorValue(iterator.index()), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -259,8 +224,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
super(dimension, 0, VectorSimilarityFunction.COSINE, null);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
|
@ -271,26 +234,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -28,12 +28,10 @@ import org.apache.lucene.search.VectorScorer;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues.Bytes {
|
||||
abstract class OffHeapByteVectorValues extends ByteVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -108,8 +106,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
|
||||
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -119,36 +115,17 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
|
@ -157,15 +134,16 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
DenseOffHeapVectorValues copy = this.copy();
|
||||
DocIndexIterator iterator = copy.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
|
||||
return vectorSimilarityFunction.compare(copy.vectorValue(iterator.index()), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -202,27 +180,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
fieldEntry.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
|
@ -234,6 +191,11 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
return (int) ordToDoc.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return fromDISI(disi);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
|
@ -255,15 +217,16 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
SparseOffHeapVectorValues copy = this.copy();
|
||||
IndexedDISI disi = copy.disi;
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return vectorSimilarityFunction.compare(copy.vectorValue(), query);
|
||||
return vectorSimilarityFunction.compare(copy.vectorValue(disi.index()), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return disi;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -275,8 +238,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
|
@ -287,26 +248,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OffHeapByteVectorValues copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -26,12 +26,10 @@ import org.apache.lucene.search.VectorScorer;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
abstract class OffHeapFloatVectorValues extends FloatVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -104,8 +102,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
|
||||
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -115,36 +111,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
super(dimension, size, slice, vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
|
@ -153,15 +130,18 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
DenseOffHeapVectorValues values = this.copy();
|
||||
DocIndexIterator iterator = values.iterator();
|
||||
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
return values.vectorSimilarityFunction.compare(
|
||||
values.vectorValue(iterator.index()), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -198,33 +178,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
fieldEntry.size());
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return IndexedDISI.asDocIndexIterator(disi);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return (int) ordToDoc.get(ord);
|
||||
|
@ -251,15 +215,17 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
SparseOffHeapVectorValues values = this.copy();
|
||||
DocIndexIterator iterator = values.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return values.vectorSimilarityFunction.compare(values.vectorValue(), query);
|
||||
return values.vectorSimilarityFunction.compare(
|
||||
values.vectorValue(iterator.index()), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return values;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -271,8 +237,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
|
@ -283,26 +247,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public OffHeapFloatVectorValues copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -29,13 +29,13 @@ import org.apache.lucene.index.ByteVectorValues;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Writes vector values and knn graphs to index segments.
|
||||
|
@ -188,12 +188,13 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
int count = 0;
|
||||
ByteBuffer binaryVector =
|
||||
ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) {
|
||||
KnnVectorValues.DocIndexIterator iter = vectors.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
float[] vectorValue = vectors.vectorValue();
|
||||
float[] vectorValue = vectors.vectorValue(iter.index());
|
||||
binaryVector.asFloatBuffer().put(vectorValue);
|
||||
output.writeBytes(binaryVector.array(), binaryVector.limit());
|
||||
docIds[count] = docV;
|
||||
docIds[count++] = docV;
|
||||
}
|
||||
|
||||
if (docIds.length > count) {
|
||||
|
@ -234,7 +235,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
|
||||
private void writeGraph(
|
||||
IndexOutput graphData,
|
||||
RandomAccessVectorValues.Floats vectorValues,
|
||||
FloatVectorValues vectorValues,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
long graphDataOffset,
|
||||
long[] offsets,
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.backward_codecs.lucene90;
|
||||
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.Objects;
|
|||
import java.util.SplittableRandom;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
|
@ -32,7 +33,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
|||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
/**
|
||||
|
@ -57,7 +57,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
|
||||
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues.Floats vectorValues;
|
||||
private final FloatVectorValues vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final Lucene91BoundsChecker bound;
|
||||
private final HnswGraphSearcher graphSearcher;
|
||||
|
@ -68,7 +68,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private RandomAccessVectorValues.Floats buildVectors;
|
||||
private FloatVectorValues buildVectors;
|
||||
|
||||
/**
|
||||
* Reads all the vectors from vector values, builds a graph connecting them by their dense
|
||||
|
@ -83,7 +83,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene91HnswGraphBuilder(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
FloatVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
|
@ -113,14 +113,14 @@ public final class Lucene91HnswGraphBuilder {
|
|||
}
|
||||
|
||||
/**
|
||||
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
|
||||
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||
* Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies
|
||||
* enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||
* returned values.
|
||||
*
|
||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
|
||||
public Lucene91OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
@ -254,7 +254,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
float[] candidate,
|
||||
float score,
|
||||
Lucene91NeighborArray neighbors,
|
||||
RandomAccessVectorValues.Floats vectorValues)
|
||||
FloatVectorValues vectorValues)
|
||||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
|
|
|
@ -17,8 +17,6 @@
|
|||
|
||||
package org.apache.lucene.backward_codecs.lucene91;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
|
@ -30,6 +28,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -37,7 +36,6 @@ import org.apache.lucene.store.IndexInput;
|
|||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Writes vector values and knn graphs to index segments.
|
||||
|
@ -183,9 +181,10 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
ByteBuffer binaryVector =
|
||||
ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = vectors.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != DocIdSetIterator.NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
float[] vectorValue = vectors.vectorValue();
|
||||
float[] vectorValue = vectors.vectorValue(iter.index());
|
||||
binaryVector.asFloatBuffer().put(vectorValue);
|
||||
output.writeBytes(binaryVector.array(), binaryVector.limit());
|
||||
docsWithField.add(docV);
|
||||
|
@ -243,7 +242,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private Lucene91OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.lucene.backward_codecs.lucene92;
|
||||
|
||||
import static org.apache.lucene.backward_codecs.lucene92.Lucene92RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
|
@ -33,6 +32,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
|
@ -43,7 +43,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
|||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||
|
||||
|
@ -190,9 +189,12 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
ByteBuffer binaryVector =
|
||||
ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iterator = vectors.iterator();
|
||||
for (int docV = iterator.nextDoc();
|
||||
docV != DocIdSetIterator.NO_MORE_DOCS;
|
||||
docV = iterator.nextDoc()) {
|
||||
// write vector
|
||||
float[] vectorValue = vectors.vectorValue();
|
||||
float[] vectorValue = vectors.vectorValue(iterator.index());
|
||||
binaryVector.asFloatBuffer().put(vectorValue);
|
||||
output.writeBytes(binaryVector.array(), binaryVector.limit());
|
||||
docsWithField.add(docV);
|
||||
|
@ -277,7 +279,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
|
||||
// build graph
|
||||
|
|
|
@ -36,6 +36,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
|
@ -52,7 +53,6 @@ import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
|||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||
|
||||
|
@ -216,9 +216,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
final int[] docIdOffsets = new int[sortMap.size()];
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
|
||||
for (int docID = iterator.nextDoc();
|
||||
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||
docID = iterator.nextDoc()) {
|
||||
for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
|
@ -556,9 +554,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
final DirectMonotonicWriter ordToDocWriter =
|
||||
DirectMonotonicWriter.getInstance(meta, vectorData, count, DIRECT_MONOTONIC_BLOCK_SHIFT);
|
||||
DocIdSetIterator iterator = docsWithField.iterator();
|
||||
for (int doc = iterator.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = iterator.nextDoc()) {
|
||||
for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) {
|
||||
ordToDocWriter.add(doc);
|
||||
}
|
||||
ordToDocWriter.finish();
|
||||
|
@ -590,11 +586,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private static DocsWithFieldSet writeByteVectorData(
|
||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
byte[] binaryValue = byteVectorValues.vectorValue();
|
||||
byte[] binaryValue = byteVectorValues.vectorValue(iter.index());
|
||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||
output.writeBytes(binaryValue, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
|
@ -608,14 +603,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private static DocsWithFieldSet writeVectorData(
|
||||
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
|
||||
ByteBuffer binaryVector =
|
||||
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
float[] vectorValue = floatVectorValues.vectorValue();
|
||||
float[] vectorValue = floatVectorValues.vectorValue(iter.index());
|
||||
binaryVector.asFloatBuffer().put(vectorValue);
|
||||
output.writeBytes(binaryVector.array(), binaryVector.limit());
|
||||
docsWithField.add(docV);
|
||||
|
@ -672,11 +666,11 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE ->
|
||||
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
RandomAccessVectorValues.fromBytes((List<byte[]>) vectors, dim));
|
||||
ByteVectorValues.fromBytes((List<byte[]>) vectors, dim));
|
||||
case FLOAT32 ->
|
||||
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
RandomAccessVectorValues.fromFloats((List<float[]>) vectors, dim));
|
||||
FloatVectorValues.fromFloats((List<float[]>) vectors, dim));
|
||||
};
|
||||
hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
|
@ -56,7 +57,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
|||
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||
|
||||
|
@ -221,9 +221,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
final int[] docIdOffsets = new int[sortMap.size()];
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
|
||||
for (int docID = iterator.nextDoc();
|
||||
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||
docID = iterator.nextDoc()) {
|
||||
for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
|
@ -482,18 +480,18 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
||||
}
|
||||
}
|
||||
DocIdSetIterator mergedVectorIterator = null;
|
||||
KnnVectorValues mergedVectorValues = null;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE ->
|
||||
mergedVectorIterator =
|
||||
mergedVectorValues =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
case FLOAT32 ->
|
||||
mergedVectorIterator =
|
||||
mergedVectorValues =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
graph =
|
||||
merger.merge(
|
||||
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
|
||||
mergedVectorValues, segmentWriteState.infoStream, docsWithField.cardinality());
|
||||
vectorIndexNodeOffsets = writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
@ -636,14 +634,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private static DocsWithFieldSet writeByteVectorData(
|
||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
|
||||
for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) {
|
||||
// write vector
|
||||
byte[] binaryValue = byteVectorValues.vectorValue();
|
||||
byte[] binaryValue = byteVectorValues.vectorValue(iter.index());
|
||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||
output.writeBytes(binaryValue, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
docsWithField.add(docId);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
@ -657,11 +654,10 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
ByteBuffer buffer =
|
||||
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
float[] value = floatVectorValues.vectorValue();
|
||||
float[] value = floatVectorValues.vectorValue(iter.index());
|
||||
buffer.asFloatBuffer().put(value);
|
||||
output.writeBytes(buffer.array(), buffer.limit());
|
||||
docsWithField.add(docV);
|
||||
|
@ -718,11 +714,11 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE ->
|
||||
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
RandomAccessVectorValues.fromBytes((List<byte[]>) vectors, dim));
|
||||
ByteVectorValues.fromBytes((List<byte[]>) vectors, dim));
|
||||
case FLOAT32 ->
|
||||
defaultFlatVectorScorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
RandomAccessVectorValues.fromFloats((List<float[]>) vectors, dim));
|
||||
FloatVectorValues.fromFloats((List<float[]>) vectors, dim));
|
||||
};
|
||||
hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
|
||||
|
|
|
@ -52,6 +52,7 @@ import org.apache.lucene.index.IndexReader;
|
|||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.LogByteSizeMergePolicy;
|
||||
import org.apache.lucene.index.MultiBits;
|
||||
|
@ -477,10 +478,14 @@ public class TestBasicBackwardsCompatibility extends BackwardsCompatibilityTestB
|
|||
FloatVectorValues values = ctx.reader().getFloatVectorValues(KNN_VECTOR_FIELD);
|
||||
if (values != null) {
|
||||
assertEquals(KNN_VECTOR_FIELD_TYPE.vectorDimension(), values.dimension());
|
||||
for (int doc = values.nextDoc(); doc != NO_MORE_DOCS; doc = values.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator it = values.iterator();
|
||||
for (int doc = it.nextDoc(); doc != NO_MORE_DOCS; doc = it.nextDoc()) {
|
||||
float[] expectedVector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * cnt};
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + cnt, expectedVector, values.vectorValue(), 0);
|
||||
"vectors do not match for doc=" + cnt,
|
||||
expectedVector,
|
||||
values.vectorValue(it.index()),
|
||||
0);
|
||||
cnt++;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit;
|
|||
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
|
@ -32,7 +33,6 @@ import org.apache.lucene.store.IndexInput;
|
|||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.openjdk.jmh.annotations.*;
|
||||
|
@ -55,7 +55,7 @@ public class VectorScorerBenchmark {
|
|||
|
||||
Directory dir;
|
||||
IndexInput in;
|
||||
RandomAccessVectorValues vectorValues;
|
||||
KnnVectorValues vectorValues;
|
||||
byte[] vec1, vec2;
|
||||
RandomVectorScorer scorer;
|
||||
|
||||
|
@ -95,7 +95,7 @@ public class VectorScorerBenchmark {
|
|||
return scorer.score(1);
|
||||
}
|
||||
|
||||
static RandomAccessVectorValues vectorValues(
|
||||
static KnnVectorValues vectorValues(
|
||||
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
|
||||
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim);
|
||||
|
@ -105,23 +105,19 @@ public class VectorScorerBenchmark {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) {
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target) {
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target) {
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,10 +19,11 @@ package org.apache.lucene.codecs.bitvectors;
|
|||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
|
@ -30,45 +31,39 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
|||
public class FlatBitVectorsScorer implements FlatVectorsScorer {
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
|
||||
throws IOException {
|
||||
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
|
||||
assert vectorValues instanceof ByteVectorValues;
|
||||
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
|
||||
return new BitRandomVectorScorerSupplier(byteVectorValues);
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"vectorValues must be an instance of RandomAccessVectorValues.Bytes");
|
||||
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
|
||||
throws IOException {
|
||||
throw new IllegalArgumentException("bit vectors do not support float[] targets");
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
|
||||
throws IOException {
|
||||
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
|
||||
assert vectorValues instanceof ByteVectorValues;
|
||||
if (vectorValues instanceof ByteVectorValues byteVectorValues) {
|
||||
return new BitRandomVectorScorer(byteVectorValues, target);
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"vectorValues must be an instance of RandomAccessVectorValues.Bytes");
|
||||
throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues");
|
||||
}
|
||||
|
||||
static class BitRandomVectorScorer implements RandomVectorScorer {
|
||||
private final RandomAccessVectorValues.Bytes vectorValues;
|
||||
private final ByteVectorValues vectorValues;
|
||||
private final int bitDimensions;
|
||||
private final byte[] query;
|
||||
|
||||
BitRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) {
|
||||
BitRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) {
|
||||
this.query = query;
|
||||
this.bitDimensions = vectorValues.dimension() * Byte.SIZE;
|
||||
this.vectorValues = vectorValues;
|
||||
|
@ -97,12 +92,11 @@ public class FlatBitVectorsScorer implements FlatVectorsScorer {
|
|||
}
|
||||
|
||||
static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
|
||||
protected final RandomAccessVectorValues.Bytes vectorValues;
|
||||
protected final RandomAccessVectorValues.Bytes vectorValues1;
|
||||
protected final RandomAccessVectorValues.Bytes vectorValues2;
|
||||
protected final ByteVectorValues vectorValues;
|
||||
protected final ByteVectorValues vectorValues1;
|
||||
protected final ByteVectorValues vectorValues2;
|
||||
|
||||
public BitRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues)
|
||||
throws IOException {
|
||||
public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException {
|
||||
this.vectorValues = vectorValues;
|
||||
this.vectorValues1 = vectorValues.copy();
|
||||
this.vectorValues2 = vectorValues.copy();
|
||||
|
|
|
@ -192,8 +192,8 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
|
||||
int doc;
|
||||
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
for (int ord = 0; ord < values.size(); ord++) {
|
||||
int doc = values.ordToDoc(ord);
|
||||
if (acceptDocs != null && acceptDocs.get(doc) == false) {
|
||||
continue;
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
break;
|
||||
}
|
||||
|
||||
float[] vector = values.vectorValue();
|
||||
float[] vector = values.vectorValue(ord);
|
||||
float score = vectorSimilarity.compare(vector, target);
|
||||
knnCollector.collect(doc, score);
|
||||
knnCollector.incVisitedCount(1);
|
||||
|
@ -223,8 +223,8 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
|
||||
|
||||
int doc;
|
||||
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
for (int ord = 0; ord < values.size(); ord++) {
|
||||
int doc = values.ordToDoc(ord);
|
||||
if (acceptDocs != null && acceptDocs.get(doc) == false) {
|
||||
continue;
|
||||
}
|
||||
|
@ -233,7 +233,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
break;
|
||||
}
|
||||
|
||||
byte[] vector = values.vectorValue();
|
||||
byte[] vector = values.vectorValue(ord);
|
||||
float score = vectorSimilarity.compare(vector, target);
|
||||
knnCollector.collect(doc, score);
|
||||
knnCollector.incVisitedCount(1);
|
||||
|
@ -327,35 +327,18 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
return values[curOrd];
|
||||
public float[] vectorValue(int ord) {
|
||||
return values[ord];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curOrd == -1) {
|
||||
return -1;
|
||||
} else if (curOrd >= entry.size()) {
|
||||
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
|
||||
// immediately afterward should also return NO_MORE_DOCS
|
||||
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
return entry.ordToDoc[curOrd];
|
||||
public int ordToDoc(int ord) {
|
||||
return entry.ordToDoc[ord];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
if (++curOrd < entry.size()) {
|
||||
return docID();
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return slowAdvance(target);
|
||||
public DocIndexIterator iterator() {
|
||||
return createSparseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -365,17 +348,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
SimpleTextFloatVectorValues simpleTextFloatVectorValues =
|
||||
new SimpleTextFloatVectorValues(this);
|
||||
DocIndexIterator iterator = simpleTextFloatVectorValues.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
int ord = iterator.index();
|
||||
return entry
|
||||
.similarityFunction()
|
||||
.compare(simpleTextFloatVectorValues.vectorValue(), target);
|
||||
.compare(simpleTextFloatVectorValues.vectorValue(ord), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return simpleTextFloatVectorValues;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -397,6 +382,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
value[i] = Float.parseFloat(floatStrings[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimpleTextFloatVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
private static class SimpleTextByteVectorValues extends ByteVectorValues {
|
||||
|
@ -439,36 +429,14 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
binaryValue.bytes = values[curOrd];
|
||||
public byte[] vectorValue(int ord) {
|
||||
binaryValue.bytes = values[ord];
|
||||
return binaryValue.bytes;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curOrd == -1) {
|
||||
return -1;
|
||||
} else if (curOrd >= entry.size()) {
|
||||
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
|
||||
// immediately afterward should also return NO_MORE_DOCS
|
||||
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
return entry.ordToDoc[curOrd];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
if (++curOrd < entry.size()) {
|
||||
return docID();
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return slowAdvance(target);
|
||||
public DocIndexIterator iterator() {
|
||||
return createSparseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -478,16 +446,19 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this);
|
||||
return new VectorScorer() {
|
||||
DocIndexIterator it = simpleTextByteVectorValues.iterator();
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
int ord = it.index();
|
||||
return entry
|
||||
.similarityFunction()
|
||||
.compare(simpleTextByteVectorValues.vectorValue(), target);
|
||||
.compare(simpleTextByteVectorValues.vectorValue(ord), target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return simpleTextByteVectorValues;
|
||||
return it;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -509,6 +480,11 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
value[i] = (byte) Float.parseFloat(floatStrings[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SimpleTextByteVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
private int readInt(IndexInput in, BytesRef field) throws IOException {
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.lucene.index.ByteVectorValues;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
@ -77,19 +78,18 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
throws IOException {
|
||||
long vectorDataOffset = vectorData.getFilePointer();
|
||||
List<Integer> docIds = new ArrayList<>();
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
writeFloatVectorValue(floatVectorValues);
|
||||
docIds.add(docV);
|
||||
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
|
||||
for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) {
|
||||
writeFloatVectorValue(floatVectorValues, iter.index());
|
||||
docIds.add(docId);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
|
||||
}
|
||||
|
||||
private void writeFloatVectorValue(FloatVectorValues vectors) throws IOException {
|
||||
private void writeFloatVectorValue(FloatVectorValues vectors, int ord) throws IOException {
|
||||
// write vector value
|
||||
float[] value = vectors.vectorValue();
|
||||
float[] value = vectors.vectorValue(ord);
|
||||
assert value.length == vectors.dimension();
|
||||
write(vectorData, Arrays.toString(value));
|
||||
newline(vectorData);
|
||||
|
@ -100,19 +100,18 @@ public class SimpleTextKnnVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
throws IOException {
|
||||
long vectorDataOffset = vectorData.getFilePointer();
|
||||
List<Integer> docIds = new ArrayList<>();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
writeByteVectorValue(byteVectorValues);
|
||||
KnnVectorValues.DocIndexIterator it = byteVectorValues.iterator();
|
||||
for (int docV = it.nextDoc(); docV != NO_MORE_DOCS; docV = it.nextDoc()) {
|
||||
writeByteVectorValue(byteVectorValues, it.index());
|
||||
docIds.add(docV);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds);
|
||||
}
|
||||
|
||||
private void writeByteVectorValue(ByteVectorValues vectors) throws IOException {
|
||||
private void writeByteVectorValue(ByteVectorValues vectors, int ord) throws IOException {
|
||||
// write vector value
|
||||
byte[] value = vectors.vectorValue();
|
||||
byte[] value = vectors.vectorValue(ord);
|
||||
assert value.length == vectors.dimension();
|
||||
write(vectorData, Arrays.toString(value));
|
||||
newline(vectorData);
|
||||
|
|
|
@ -20,14 +20,16 @@ package org.apache.lucene.codecs;
|
|||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Supplier;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.index.SortingCodecReader;
|
||||
import org.apache.lucene.index.SortingCodecReader.SortingValuesIterator;
|
||||
import org.apache.lucene.search.DocIdSet;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
|
||||
|
@ -80,24 +82,26 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
case FLOAT32:
|
||||
BufferedFloatVectorValues bufferedFloatVectorValues =
|
||||
new BufferedFloatVectorValues(
|
||||
fieldData.docsWithField,
|
||||
(List<float[]>) fieldData.vectors,
|
||||
fieldData.fieldInfo.getVectorDimension());
|
||||
fieldData.fieldInfo.getVectorDimension(),
|
||||
fieldData.docsWithField);
|
||||
FloatVectorValues floatVectorValues =
|
||||
sortMap != null
|
||||
? new SortingFloatVectorValues(bufferedFloatVectorValues, sortMap)
|
||||
? new SortingFloatVectorValues(
|
||||
bufferedFloatVectorValues, fieldData.docsWithField, sortMap)
|
||||
: bufferedFloatVectorValues;
|
||||
writeField(fieldData.fieldInfo, floatVectorValues, maxDoc);
|
||||
break;
|
||||
case BYTE:
|
||||
BufferedByteVectorValues bufferedByteVectorValues =
|
||||
new BufferedByteVectorValues(
|
||||
fieldData.docsWithField,
|
||||
(List<byte[]>) fieldData.vectors,
|
||||
fieldData.fieldInfo.getVectorDimension());
|
||||
fieldData.fieldInfo.getVectorDimension(),
|
||||
fieldData.docsWithField);
|
||||
ByteVectorValues byteVectorValues =
|
||||
sortMap != null
|
||||
? new SortingByteVectorValues(bufferedByteVectorValues, sortMap)
|
||||
? new SortingByteVectorValues(
|
||||
bufferedByteVectorValues, fieldData.docsWithField, sortMap)
|
||||
: bufferedByteVectorValues;
|
||||
writeField(fieldData.fieldInfo, byteVectorValues, maxDoc);
|
||||
break;
|
||||
|
@ -107,125 +111,77 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
|
||||
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
private static class SortingFloatVectorValues extends FloatVectorValues {
|
||||
private final BufferedFloatVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private int docId = -1;
|
||||
private final BufferedFloatVectorValues delegate;
|
||||
private final Supplier<SortingValuesIterator> iteratorSupplier;
|
||||
|
||||
SortingFloatVectorValues(BufferedFloatVectorValues delegate, Sorter.DocMap sortMap)
|
||||
SortingFloatVectorValues(
|
||||
BufferedFloatVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap)
|
||||
throws IOException {
|
||||
this.randomAccess = delegate.copy();
|
||||
this.docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
int docID;
|
||||
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
this.delegate = delegate.copy();
|
||||
iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (docId < docIdOffsets.length - 1) {
|
||||
++docId;
|
||||
if (docIdOffsets[docId] != 0) {
|
||||
return docId;
|
||||
}
|
||||
}
|
||||
docId = NO_MORE_DOCS;
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
return delegate.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return randomAccess.dimension();
|
||||
return delegate.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return randomAccess.size();
|
||||
return delegate.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
public SortingFloatVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return iteratorSupplier.get();
|
||||
}
|
||||
}
|
||||
|
||||
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
/** Sorting ByteVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||
private final BufferedByteVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private int docId = -1;
|
||||
private final BufferedByteVectorValues delegate;
|
||||
private final Supplier<SortingValuesIterator> iteratorSupplier;
|
||||
|
||||
SortingByteVectorValues(BufferedByteVectorValues delegate, Sorter.DocMap sortMap)
|
||||
SortingByteVectorValues(
|
||||
BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap)
|
||||
throws IOException {
|
||||
this.randomAccess = delegate.copy();
|
||||
this.docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
int docID;
|
||||
while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
this.delegate = delegate;
|
||||
iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (docId < docIdOffsets.length - 1) {
|
||||
++docId;
|
||||
if (docIdOffsets[docId] != 0) {
|
||||
return docId;
|
||||
}
|
||||
}
|
||||
docId = NO_MORE_DOCS;
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return randomAccess.vectorValue(docIdOffsets[docId] - 1);
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
return delegate.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return randomAccess.dimension();
|
||||
return delegate.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return randomAccess.size();
|
||||
return delegate.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
public SortingByteVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return iteratorSupplier.get();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -296,7 +252,9 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
|
||||
@Override
|
||||
public final long ramBytesUsed() {
|
||||
if (vectors.size() == 0) return 0;
|
||||
if (vectors.isEmpty()) {
|
||||
return 0;
|
||||
}
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ vectors.size()
|
||||
* (long)
|
||||
|
@ -307,25 +265,18 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
private static class BufferedFloatVectorValues extends FloatVectorValues {
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
|
||||
final List<float[]> vectors;
|
||||
final int dimension;
|
||||
private final DocIdSet docsWithField;
|
||||
private final DocIndexIterator iterator;
|
||||
|
||||
DocIdSetIterator docsWithFieldIter;
|
||||
int ord = -1;
|
||||
|
||||
BufferedFloatVectorValues(
|
||||
DocsWithFieldSet docsWithField, List<float[]> vectors, int dimension) {
|
||||
this.docsWithField = docsWithField;
|
||||
BufferedFloatVectorValues(List<float[]> vectors, int dimension, DocIdSet docsWithField)
|
||||
throws IOException {
|
||||
this.vectors = vectors;
|
||||
this.dimension = dimension;
|
||||
docsWithFieldIter = docsWithField.iterator();
|
||||
}
|
||||
|
||||
public BufferedFloatVectorValues copy() {
|
||||
return new BufferedFloatVectorValues(docsWithField, vectors, dimension);
|
||||
this.docsWithField = docsWithField;
|
||||
this.iterator = fromDISI(docsWithField.iterator());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -339,58 +290,39 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
return vectors.get(ord);
|
||||
public int ordToDoc(int ord) {
|
||||
return ord;
|
||||
}
|
||||
|
||||
float[] vectorValue(int targetOrd) {
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docsWithFieldIter.docID();
|
||||
public DocIndexIterator iterator() {
|
||||
return iterator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int docID = docsWithFieldIter.nextDoc();
|
||||
if (docID != NO_MORE_DOCS) {
|
||||
++ord;
|
||||
}
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public BufferedFloatVectorValues copy() throws IOException {
|
||||
return new BufferedFloatVectorValues(vectors, dimension, docsWithField);
|
||||
}
|
||||
}
|
||||
|
||||
private static class BufferedByteVectorValues extends ByteVectorValues {
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
// These are always the vectors of a VectorValuesWriter, which are copied when added to it
|
||||
final List<byte[]> vectors;
|
||||
final int dimension;
|
||||
private final DocIdSet docsWithField;
|
||||
private final DocIndexIterator iterator;
|
||||
|
||||
DocIdSetIterator docsWithFieldIter;
|
||||
int ord = -1;
|
||||
|
||||
BufferedByteVectorValues(DocsWithFieldSet docsWithField, List<byte[]> vectors, int dimension) {
|
||||
this.docsWithField = docsWithField;
|
||||
BufferedByteVectorValues(List<byte[]> vectors, int dimension, DocIdSet docsWithField)
|
||||
throws IOException {
|
||||
this.vectors = vectors;
|
||||
this.dimension = dimension;
|
||||
docsWithFieldIter = docsWithField.iterator();
|
||||
}
|
||||
|
||||
public BufferedByteVectorValues copy() {
|
||||
return new BufferedByteVectorValues(docsWithField, vectors, dimension);
|
||||
this.docsWithField = docsWithField;
|
||||
iterator = fromDISI(docsWithField.iterator());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -404,36 +336,18 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
return vectors.get(ord);
|
||||
}
|
||||
|
||||
byte[] vectorValue(int targetOrd) {
|
||||
public byte[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docsWithFieldIter.docID();
|
||||
public DocIndexIterator iterator() {
|
||||
return iterator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int docID = docsWithFieldIter.nextDoc();
|
||||
if (docID != NO_MORE_DOCS) {
|
||||
++ord;
|
||||
}
|
||||
return docID;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public BufferedByteVectorValues copy() throws IOException {
|
||||
return new BufferedByteVectorValues(vectors, dimension, docsWithField);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
|
@ -55,28 +56,26 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE:
|
||||
case BYTE -> {
|
||||
KnnFieldVectorsWriter<byte[]> byteWriter =
|
||||
(KnnFieldVectorsWriter<byte[]>) addField(fieldInfo);
|
||||
ByteVectorValues mergedBytes =
|
||||
MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedBytes.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedBytes.nextDoc()) {
|
||||
byteWriter.addValue(doc, mergedBytes.vectorValue());
|
||||
KnnVectorValues.DocIndexIterator iter = mergedBytes.iterator();
|
||||
for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
|
||||
byteWriter.addValue(doc, mergedBytes.vectorValue(iter.index()));
|
||||
}
|
||||
break;
|
||||
case FLOAT32:
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
KnnFieldVectorsWriter<float[]> floatWriter =
|
||||
(KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
|
||||
FloatVectorValues mergedFloats =
|
||||
MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedFloats.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedFloats.nextDoc()) {
|
||||
floatWriter.addValue(doc, mergedFloats.vectorValue());
|
||||
KnnVectorValues.DocIndexIterator iter = mergedFloats.iterator();
|
||||
for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
|
||||
floatWriter.addValue(doc, mergedFloats.vectorValue(iter.index()));
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,32 +116,44 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
private static class FloatVectorValuesSub extends DocIDMerger.Sub {
|
||||
|
||||
final FloatVectorValues values;
|
||||
final KnnVectorValues.DocIndexIterator iterator;
|
||||
|
||||
FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) {
|
||||
super(docMap);
|
||||
this.values = values;
|
||||
assert values.docID() == -1;
|
||||
this.iterator = values.iterator();
|
||||
assert iterator.docID() == -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return values.nextDoc();
|
||||
return iterator.nextDoc();
|
||||
}
|
||||
|
||||
public int index() {
|
||||
return iterator.index();
|
||||
}
|
||||
}
|
||||
|
||||
private static class ByteVectorValuesSub extends DocIDMerger.Sub {
|
||||
|
||||
final ByteVectorValues values;
|
||||
final KnnVectorValues.DocIndexIterator iterator;
|
||||
|
||||
ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) {
|
||||
super(docMap);
|
||||
this.values = values;
|
||||
assert values.docID() == -1;
|
||||
iterator = values.iterator();
|
||||
assert iterator.docID() == -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return values.nextDoc();
|
||||
return iterator.nextDoc();
|
||||
}
|
||||
|
||||
int index() {
|
||||
return iterator.index();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -287,7 +298,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
private final List<FloatVectorValuesSub> subs;
|
||||
private final DocIDMerger<FloatVectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
private int docId;
|
||||
private int docId = -1;
|
||||
private int lastOrd = -1;
|
||||
FloatVectorValuesSub current;
|
||||
|
||||
private MergedFloat32VectorValues(List<FloatVectorValuesSub> subs, MergeState mergeState)
|
||||
|
@ -299,33 +311,57 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return new DocIndexIterator() {
|
||||
private int index = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return index;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
index = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
++index;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public long cost() {
|
||||
return size;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
if (ord != lastOrd + 1) {
|
||||
throw new IllegalStateException(
|
||||
"only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd);
|
||||
} else {
|
||||
lastOrd = ord;
|
||||
}
|
||||
return current.values.vectorValue(current.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -338,10 +374,20 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
return subs.get(0).values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
static class MergedByteVectorValues extends ByteVectorValues {
|
||||
|
@ -349,7 +395,8 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
|
||||
private int docId;
|
||||
private int lastOrd = -1;
|
||||
private int docId = -1;
|
||||
ByteVectorValuesSub current;
|
||||
|
||||
private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
|
||||
|
@ -361,35 +408,59 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
if (ord != lastOrd + 1) {
|
||||
throw new IllegalStateException(
|
||||
"only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd);
|
||||
} else {
|
||||
lastOrd = ord;
|
||||
}
|
||||
return current.values.vectorValue(current.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return new DocIndexIterator() {
|
||||
private int index = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return index;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
index = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
++index;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return size;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
|
@ -400,10 +471,20 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
return subs.get(0).values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
|
@ -34,24 +36,26 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
|
||||
throws IOException {
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Floats floatVectorValues) {
|
||||
return new FloatScoringSupplier(floatVectorValues, similarityFunction);
|
||||
} else if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) {
|
||||
return new ByteScoringSupplier(byteVectorValues, similarityFunction);
|
||||
switch (vectorValues.getEncoding()) {
|
||||
case FLOAT32 -> {
|
||||
return new FloatScoringSupplier((FloatVectorValues) vectorValues, similarityFunction);
|
||||
}
|
||||
case BYTE -> {
|
||||
return new ByteScoringSupplier((ByteVectorValues) vectorValues, similarityFunction);
|
||||
}
|
||||
}
|
||||
throw new IllegalArgumentException(
|
||||
"vectorValues must be an instance of RandomAccessVectorValues.Floats or RandomAccessVectorValues.Bytes");
|
||||
"vectorValues must be an instance of FloatVectorValues or ByteVectorValues, got a "
|
||||
+ vectorValues.getClass().getName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
|
||||
throws IOException {
|
||||
assert vectorValues instanceof RandomAccessVectorValues.Floats;
|
||||
assert vectorValues instanceof FloatVectorValues;
|
||||
if (target.length != vectorValues.dimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector query dimension: "
|
||||
|
@ -59,17 +63,14 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
+ " differs from field dimension: "
|
||||
+ vectorValues.dimension());
|
||||
}
|
||||
return new FloatVectorScorer(
|
||||
(RandomAccessVectorValues.Floats) vectorValues, target, similarityFunction);
|
||||
return new FloatVectorScorer((FloatVectorValues) vectorValues, target, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
|
||||
throws IOException {
|
||||
assert vectorValues instanceof RandomAccessVectorValues.Bytes;
|
||||
assert vectorValues instanceof ByteVectorValues;
|
||||
if (target.length != vectorValues.dimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector query dimension: "
|
||||
|
@ -77,8 +78,7 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
+ " differs from field dimension: "
|
||||
+ vectorValues.dimension());
|
||||
}
|
||||
return new ByteVectorScorer(
|
||||
(RandomAccessVectorValues.Bytes) vectorValues, target, similarityFunction);
|
||||
return new ByteVectorScorer((ByteVectorValues) vectorValues, target, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -88,14 +88,13 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
/** RandomVectorScorerSupplier for bytes vector */
|
||||
private static final class ByteScoringSupplier implements RandomVectorScorerSupplier {
|
||||
private final RandomAccessVectorValues.Bytes vectors;
|
||||
private final RandomAccessVectorValues.Bytes vectors1;
|
||||
private final RandomAccessVectorValues.Bytes vectors2;
|
||||
private final ByteVectorValues vectors;
|
||||
private final ByteVectorValues vectors1;
|
||||
private final ByteVectorValues vectors2;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
private ByteScoringSupplier(
|
||||
RandomAccessVectorValues.Bytes vectors, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
|
||||
this.vectors = vectors;
|
||||
vectors1 = vectors.copy();
|
||||
vectors2 = vectors.copy();
|
||||
|
@ -125,14 +124,13 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
/** RandomVectorScorerSupplier for Float vector */
|
||||
private static final class FloatScoringSupplier implements RandomVectorScorerSupplier {
|
||||
private final RandomAccessVectorValues.Floats vectors;
|
||||
private final RandomAccessVectorValues.Floats vectors1;
|
||||
private final RandomAccessVectorValues.Floats vectors2;
|
||||
private final FloatVectorValues vectors;
|
||||
private final FloatVectorValues vectors1;
|
||||
private final FloatVectorValues vectors2;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
private FloatScoringSupplier(
|
||||
RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException {
|
||||
this.vectors = vectors;
|
||||
vectors1 = vectors.copy();
|
||||
vectors2 = vectors.copy();
|
||||
|
@ -162,14 +160,12 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
/** A {@link RandomVectorScorer} for float vectors. */
|
||||
private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
private final RandomAccessVectorValues.Floats values;
|
||||
private final FloatVectorValues values;
|
||||
private final float[] query;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
public FloatVectorScorer(
|
||||
RandomAccessVectorValues.Floats values,
|
||||
float[] query,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
FloatVectorValues values, float[] query, VectorSimilarityFunction similarityFunction) {
|
||||
super(values);
|
||||
this.values = values;
|
||||
this.query = query;
|
||||
|
@ -184,14 +180,12 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
/** A {@link RandomVectorScorer} for byte vectors. */
|
||||
private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
private final RandomAccessVectorValues.Bytes values;
|
||||
private final ByteVectorValues values;
|
||||
private final byte[] query;
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
|
||||
public ByteVectorScorer(
|
||||
RandomAccessVectorValues.Bytes values,
|
||||
byte[] query,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
ByteVectorValues values, byte[] query, VectorSimilarityFunction similarityFunction) {
|
||||
super(values);
|
||||
this.values = values;
|
||||
this.query = query;
|
||||
|
|
|
@ -18,8 +18,8 @@
|
|||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
|
@ -40,7 +40,19 @@ public interface FlatVectorsScorer {
|
|||
* @throws IOException if an I/O error occurs
|
||||
*/
|
||||
RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns a {@link RandomVectorScorer} for the given set of vectors and target vector.
|
||||
*
|
||||
* @param similarityFunction the similarity function to use
|
||||
* @param vectorValues the vector values to score
|
||||
* @param target the target vector
|
||||
* @return a {@link RandomVectorScorer} for the given field and target vector.
|
||||
* @throws IOException if an I/O error occurs when reading from the index.
|
||||
*/
|
||||
RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
|
||||
throws IOException;
|
||||
|
||||
/**
|
||||
|
@ -53,23 +65,6 @@ public interface FlatVectorsScorer {
|
|||
* @throws IOException if an I/O error occurs when reading from the index.
|
||||
*/
|
||||
RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
throws IOException;
|
||||
|
||||
/**
|
||||
* Returns a {@link RandomVectorScorer} for the given set of vectors and target vector.
|
||||
*
|
||||
* @param similarityFunction the similarity function to use
|
||||
* @param vectorValues the vector values to score
|
||||
* @param target the target vector
|
||||
* @return a {@link RandomVectorScorer} for the given field and target vector.
|
||||
* @throws IOException if an I/O error occurs when reading from the index.
|
||||
*/
|
||||
RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
|
||||
throws IOException;
|
||||
}
|
||||
|
|
|
@ -18,13 +18,13 @@
|
|||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
|
||||
|
@ -60,9 +60,9 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
|
||||
throws IOException {
|
||||
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
return new ScalarQuantizedRandomVectorScorerSupplier(
|
||||
similarityFunction,
|
||||
quantizedByteVectorValues.getScalarQuantizer(),
|
||||
|
@ -74,11 +74,9 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
|
||||
throws IOException {
|
||||
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
|
||||
byte[] targetBytes = new byte[target.length];
|
||||
float offsetCorrection =
|
||||
|
@ -104,9 +102,7 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
|
||||
throws IOException {
|
||||
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
|
||||
}
|
||||
|
@ -124,14 +120,14 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
public static class ScalarQuantizedRandomVectorScorerSupplier
|
||||
implements RandomVectorScorerSupplier {
|
||||
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final QuantizedByteVectorValues values;
|
||||
private final ScalarQuantizedVectorSimilarity similarity;
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
||||
public ScalarQuantizedRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
RandomAccessQuantizedByteVectorValues values) {
|
||||
QuantizedByteVectorValues values) {
|
||||
this.similarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
similarityFunction,
|
||||
|
@ -144,7 +140,7 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
private ScalarQuantizedRandomVectorScorerSupplier(
|
||||
ScalarQuantizedVectorSimilarity similarity,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
RandomAccessQuantizedByteVectorValues values) {
|
||||
QuantizedByteVectorValues values) {
|
||||
this.similarity = similarity;
|
||||
this.values = values;
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
|
@ -152,7 +148,7 @@ public class ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
|
||||
final QuantizedByteVectorValues vectorsCopy = values.copy();
|
||||
final byte[] queryVector = values.vectorValue(ord);
|
||||
final float queryOffset = values.getScoreCorrectionConstant(ord);
|
||||
return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) {
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene90;
|
|||
|
||||
import java.io.DataInput;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
|
@ -439,6 +440,40 @@ public final class IndexedDISI extends DocIdSetIterator {
|
|||
// ALL variables
|
||||
int gap;
|
||||
|
||||
/**
|
||||
* Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance the
|
||||
* underlying IndexedDISI, and vice-versa.
|
||||
*/
|
||||
public static KnnVectorValues.DocIndexIterator asDocIndexIterator(IndexedDISI disi) {
|
||||
// can we replace with fromDISI?
|
||||
return new KnnVectorValues.DocIndexIterator() {
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return disi.index();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return disi.cost();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
|
|
|
@ -14,23 +14,16 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util.quantization;
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
||||
/**
|
||||
* Random access values for <code>byte[]</code>, but also includes accessing the score correction
|
||||
* constant for the current vector in the buffer.
|
||||
*
|
||||
* @lucene.experimental
|
||||
* Implementors can return the IndexInput from which their values are read. For use by vector
|
||||
* quantizers.
|
||||
*/
|
||||
public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues.Bytes {
|
||||
public interface HasIndexSlice {
|
||||
|
||||
ScalarQuantizer getScalarQuantizer();
|
||||
|
||||
float getScoreCorrectionConstant(int vectorOrd) throws IOException;
|
||||
|
||||
@Override
|
||||
RandomAccessQuantizedByteVectorValues copy() throws IOException;
|
||||
/** Returns an IndexInput from which to read this instance's values. */
|
||||
IndexInput getSlice();
|
||||
}
|
|
@ -29,13 +29,11 @@ import org.apache.lucene.search.VectorScorer;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues.Bytes {
|
||||
public abstract class OffHeapByteVectorValues extends ByteVectorValues implements HasIndexSlice {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -132,9 +130,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
* vector.
|
||||
*/
|
||||
public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -145,36 +140,17 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
|
@ -183,17 +159,18 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
@Override
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
DenseOffHeapVectorValues copy = copy();
|
||||
DocIndexIterator iterator = copy.iterator();
|
||||
RandomVectorScorer scorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return scorer.score(copy.doc);
|
||||
return scorer.score(iterator.docID());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -238,27 +215,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
configuration.size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
|
@ -276,6 +232,11 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
return (int) ordToDoc.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return IndexedDISI.asDocIndexIterator(disi);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
|
@ -307,7 +268,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return copy.disi;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -322,8 +283,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
|
@ -335,23 +294,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -359,11 +308,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -28,13 +28,11 @@ import org.apache.lucene.search.VectorScorer;
|
|||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
public abstract class OffHeapFloatVectorValues extends FloatVectorValues implements HasIndexSlice {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -128,8 +126,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
*/
|
||||
public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -140,55 +136,42 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
DenseOffHeapVectorValues copy = copy();
|
||||
DocIndexIterator iterator = copy.iterator();
|
||||
RandomVectorScorer randomVectorScorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return randomVectorScorer.score(copy.doc);
|
||||
return randomVectorScorer.score(iterator.docID());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -227,27 +210,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
configuration.size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(
|
||||
|
@ -283,20 +245,26 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return IndexedDISI.asDocIndexIterator(disi);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
SparseOffHeapVectorValues copy = copy();
|
||||
DocIndexIterator iterator = copy.iterator();
|
||||
RandomVectorScorer randomVectorScorer =
|
||||
flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return randomVectorScorer.score(copy.disi.index());
|
||||
return randomVectorScorer.score(iterator.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -311,8 +279,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
|
@ -323,26 +289,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmptyOffHeapVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
|
@ -354,8 +300,8 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -39,6 +39,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
|
@ -361,11 +362,10 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
|
|||
private static DocsWithFieldSet writeByteVectorData(
|
||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
byte[] binaryValue = byteVectorValues.vectorValue();
|
||||
byte[] binaryValue = byteVectorValues.vectorValue(iter.index());
|
||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||
output.writeBytes(binaryValue, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
|
@ -382,11 +382,10 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
|
|||
ByteBuffer buffer =
|
||||
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
float[] value = floatVectorValues.vectorValue();
|
||||
float[] value = floatVectorValues.vectorValue(iter.index());
|
||||
buffer.asFloatBuffer().put(value);
|
||||
output.writeBytes(buffer.array(), buffer.limit());
|
||||
docsWithField.add(docV);
|
||||
|
|
|
@ -32,14 +32,16 @@ import org.apache.lucene.codecs.KnnVectorsWriter;
|
|||
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.TaskExecutor;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
|
@ -54,7 +56,6 @@ import org.apache.lucene.util.hnsw.HnswGraphMerger;
|
|||
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||
|
||||
|
@ -359,18 +360,18 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
|||
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
||||
}
|
||||
}
|
||||
DocIdSetIterator mergedVectorIterator = null;
|
||||
KnnVectorValues mergedVectorValues = null;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE ->
|
||||
mergedVectorIterator =
|
||||
mergedVectorValues =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
case FLOAT32 ->
|
||||
mergedVectorIterator =
|
||||
mergedVectorValues =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
graph =
|
||||
merger.merge(
|
||||
mergedVectorIterator,
|
||||
mergedVectorValues,
|
||||
segmentWriteState.infoStream,
|
||||
scorerSupplier.totalVectorCount());
|
||||
vectorIndexNodeOffsets = writeGraph(graph);
|
||||
|
@ -582,13 +583,13 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE ->
|
||||
scorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
RandomAccessVectorValues.fromBytes(
|
||||
ByteVectorValues.fromBytes(
|
||||
(List<byte[]>) flatFieldVectorsWriter.getVectors(),
|
||||
fieldInfo.getVectorDimension()));
|
||||
case FLOAT32 ->
|
||||
scorer.getRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
RandomAccessVectorValues.fromFloats(
|
||||
FloatVectorValues.fromFloats(
|
||||
(List<float[]>) flatFieldVectorsWriter.getVectors(),
|
||||
fieldInfo.getVectorDimension()));
|
||||
};
|
||||
|
|
|
@ -21,12 +21,12 @@ import static org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer.quantize
|
|||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
|
||||
/**
|
||||
|
@ -45,9 +45,9 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues)
|
||||
throws IOException {
|
||||
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
return new ScalarQuantizedRandomVectorScorerSupplier(
|
||||
quantizedByteVectorValues, similarityFunction);
|
||||
}
|
||||
|
@ -57,11 +57,9 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target)
|
||||
throws IOException {
|
||||
if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) {
|
||||
ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer();
|
||||
byte[] targetBytes = new byte[target.length];
|
||||
float offsetCorrection =
|
||||
|
@ -79,9 +77,7 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] target)
|
||||
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target)
|
||||
throws IOException {
|
||||
return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target);
|
||||
}
|
||||
|
@ -96,7 +92,7 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
float offsetCorrection,
|
||||
VectorSimilarityFunction sim,
|
||||
float constMultiplier,
|
||||
RandomAccessQuantizedByteVectorValues values) {
|
||||
QuantizedByteVectorValues values) {
|
||||
return switch (sim) {
|
||||
case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes);
|
||||
case COSINE, DOT_PRODUCT ->
|
||||
|
@ -120,7 +116,7 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
byte[] targetBytes,
|
||||
float offsetCorrection,
|
||||
float constMultiplier,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
QuantizedByteVectorValues values,
|
||||
FloatToFloatFunction scoreAdjustmentFunction) {
|
||||
if (values.getScalarQuantizer().getBits() <= 4) {
|
||||
if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) {
|
||||
|
@ -137,10 +133,9 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
private final float constMultiplier;
|
||||
private final byte[] targetBytes;
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final QuantizedByteVectorValues values;
|
||||
|
||||
private Euclidean(
|
||||
RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
|
||||
private Euclidean(QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
|
||||
super(values);
|
||||
this.values = values;
|
||||
this.constMultiplier = constMultiplier;
|
||||
|
@ -159,13 +154,13 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
/** Calculates dot product on quantized vectors, applying the appropriate corrections */
|
||||
private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
private final float constMultiplier;
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final QuantizedByteVectorValues values;
|
||||
private final byte[] targetBytes;
|
||||
private final float offsetCorrection;
|
||||
private final FloatToFloatFunction scoreAdjustmentFunction;
|
||||
|
||||
public DotProduct(
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
QuantizedByteVectorValues values,
|
||||
float constMultiplier,
|
||||
byte[] targetBytes,
|
||||
float offsetCorrection,
|
||||
|
@ -193,14 +188,14 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
private static class CompressedInt4DotProduct
|
||||
extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
private final float constMultiplier;
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final QuantizedByteVectorValues values;
|
||||
private final byte[] compressedVector;
|
||||
private final byte[] targetBytes;
|
||||
private final float offsetCorrection;
|
||||
private final FloatToFloatFunction scoreAdjustmentFunction;
|
||||
|
||||
private CompressedInt4DotProduct(
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
QuantizedByteVectorValues values,
|
||||
float constMultiplier,
|
||||
byte[] targetBytes,
|
||||
float offsetCorrection,
|
||||
|
@ -231,13 +226,13 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
|
||||
private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
|
||||
private final float constMultiplier;
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final QuantizedByteVectorValues values;
|
||||
private final byte[] targetBytes;
|
||||
private final float offsetCorrection;
|
||||
private final FloatToFloatFunction scoreAdjustmentFunction;
|
||||
|
||||
public Int4DotProduct(
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
QuantizedByteVectorValues values,
|
||||
float constMultiplier,
|
||||
byte[] targetBytes,
|
||||
float offsetCorrection,
|
||||
|
@ -271,13 +266,12 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
|
|||
implements RandomVectorScorerSupplier {
|
||||
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final RandomAccessQuantizedByteVectorValues values1;
|
||||
private final RandomAccessQuantizedByteVectorValues values2;
|
||||
private final QuantizedByteVectorValues values;
|
||||
private final QuantizedByteVectorValues values1;
|
||||
private final QuantizedByteVectorValues values2;
|
||||
|
||||
public ScalarQuantizedRandomVectorScorerSupplier(
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
VectorSimilarityFunction vectorSimilarityFunction)
|
||||
QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction)
|
||||
throws IOException {
|
||||
this.values = values;
|
||||
this.values1 = values.copy();
|
||||
|
|
|
@ -402,10 +402,10 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
|
||||
private static final class QuantizedVectorValues extends FloatVectorValues {
|
||||
private final FloatVectorValues rawVectorValues;
|
||||
private final OffHeapQuantizedByteVectorValues quantizedVectorValues;
|
||||
private final QuantizedByteVectorValues quantizedVectorValues;
|
||||
|
||||
QuantizedVectorValues(
|
||||
FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) {
|
||||
FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) {
|
||||
this.rawVectorValues = rawVectorValues;
|
||||
this.quantizedVectorValues = quantizedVectorValues;
|
||||
}
|
||||
|
@ -421,34 +421,28 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return rawVectorValues.vectorValue();
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
return rawVectorValues.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return rawVectorValues.docID();
|
||||
public int ordToDoc(int ord) {
|
||||
return rawVectorValues.ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int rawDocId = rawVectorValues.nextDoc();
|
||||
int quantizedDocId = quantizedVectorValues.nextDoc();
|
||||
assert rawDocId == quantizedDocId;
|
||||
return quantizedDocId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
int rawDocId = rawVectorValues.advance(target);
|
||||
int quantizedDocId = quantizedVectorValues.advance(target);
|
||||
assert rawDocId == quantizedDocId;
|
||||
return quantizedDocId;
|
||||
public QuantizedVectorValues copy() throws IOException {
|
||||
return new QuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy());
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
return quantizedVectorValues.scorer(query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return rawVectorValues.iterator();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,9 +19,7 @@ package org.apache.lucene.codecs.lucene99;
|
|||
|
||||
import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.*;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
|
||||
|
||||
|
@ -45,6 +43,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
|
@ -653,12 +652,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
|| bits <= 4
|
||||
|| shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
|
||||
int numVectors = 0;
|
||||
FloatVectorValues vectorValues =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
DocIdSetIterator iter =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)
|
||||
.iterator();
|
||||
// iterate vectorValues and increment numVectors
|
||||
for (int doc = vectorValues.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = vectorValues.nextDoc()) {
|
||||
for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) {
|
||||
numVectors++;
|
||||
}
|
||||
return buildScalarQuantizer(
|
||||
|
@ -730,11 +728,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
? OffHeapQuantizedByteVectorValues.compressedArray(
|
||||
quantizedByteVectorValues.dimension(), bits)
|
||||
: null;
|
||||
for (int docV = quantizedByteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = quantizedByteVectorValues.nextDoc()) {
|
||||
KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator();
|
||||
for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) {
|
||||
// write vector
|
||||
byte[] binaryValue = quantizedByteVectorValues.vectorValue();
|
||||
byte[] binaryValue = quantizedByteVectorValues.vectorValue(iter.index());
|
||||
assert binaryValue.length == quantizedByteVectorValues.dimension()
|
||||
: "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
|
||||
if (compressedVector != null) {
|
||||
|
@ -743,7 +740,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
} else {
|
||||
output.writeBytes(binaryValue, binaryValue.length);
|
||||
}
|
||||
output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
|
||||
output.writeInt(
|
||||
Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(iter.index())));
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
|
@ -855,7 +853,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
|
||||
static class FloatVectorWrapper extends FloatVectorValues {
|
||||
private final List<float[]> vectorList;
|
||||
protected int curDoc = -1;
|
||||
|
||||
FloatVectorWrapper(List<float[]> vectorList) {
|
||||
this.vectorList = vectorList;
|
||||
|
@ -872,51 +869,42 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
if (curDoc == -1 || curDoc >= vectorList.size()) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
return vectorList.get(curDoc);
|
||||
public FloatVectorValues copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curDoc >= vectorList.size()) {
|
||||
return NO_MORE_DOCS;
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
if (ord < 0 || ord >= vectorList.size()) {
|
||||
throw new IOException("vector ord " + ord + " out of bounds");
|
||||
}
|
||||
return curDoc;
|
||||
return vectorList.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
curDoc++;
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
curDoc = target;
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
}
|
||||
|
||||
static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
|
||||
private final QuantizedByteVectorValues values;
|
||||
private final KnnVectorValues.DocIndexIterator iterator;
|
||||
|
||||
QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
|
||||
super(docMap);
|
||||
this.values = values;
|
||||
assert values.docID() == -1;
|
||||
iterator = values.iterator();
|
||||
assert iterator.docID() == -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return values.nextDoc();
|
||||
return iterator.nextDoc();
|
||||
}
|
||||
|
||||
public int index() {
|
||||
return iterator.index();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -973,7 +961,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
private final DocIDMerger<QuantizedByteVectorValueSub> docIdMerger;
|
||||
private final int size;
|
||||
|
||||
private int docId;
|
||||
private QuantizedByteVectorValueSub current;
|
||||
|
||||
private MergedQuantizedVectorValues(
|
||||
|
@ -985,33 +972,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
return current.values.vectorValue(current.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return new CompositeIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1025,20 +995,59 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public float getScoreCorrectionConstant() throws IOException {
|
||||
return current.values.getScoreCorrectionConstant();
|
||||
public float getScoreCorrectionConstant(int ord) throws IOException {
|
||||
return current.values.getScoreCorrectionConstant(current.index());
|
||||
}
|
||||
|
||||
private class CompositeIterator extends DocIndexIterator {
|
||||
private int docId;
|
||||
private int ord;
|
||||
|
||||
public CompositeIterator() {
|
||||
docId = -1;
|
||||
ord = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
public int index() {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
ord = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
++ord;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return size;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
|
||||
private final FloatVectorValues values;
|
||||
private final ScalarQuantizer quantizer;
|
||||
private final byte[] quantizedVector;
|
||||
private int lastOrd = -1;
|
||||
private float offsetValue = 0f;
|
||||
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
@ -1054,7 +1063,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public float getScoreCorrectionConstant() {
|
||||
public float getScoreCorrectionConstant(int ord) {
|
||||
if (ord != lastOrd) {
|
||||
throw new IllegalStateException(
|
||||
"attempt to retrieve score correction for different ord "
|
||||
+ ord
|
||||
+ " than the quantization was done for: "
|
||||
+ lastOrd);
|
||||
}
|
||||
return offsetValue;
|
||||
}
|
||||
|
||||
|
@ -1069,41 +1085,31 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
if (ord != lastOrd) {
|
||||
offsetValue = quantize(ord);
|
||||
lastOrd = ord;
|
||||
}
|
||||
return quantizedVector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return values.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int doc = values.nextDoc();
|
||||
if (doc != NO_MORE_DOCS) {
|
||||
quantize();
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
int doc = values.advance(target);
|
||||
if (doc != NO_MORE_DOCS) {
|
||||
quantize();
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
private void quantize() throws IOException {
|
||||
offsetValue =
|
||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
||||
private float quantize(int ord) throws IOException {
|
||||
return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return values.ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return values.iterator();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1160,9 +1166,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public float getScoreCorrectionConstant() throws IOException {
|
||||
public float getScoreCorrectionConstant(int ord) throws IOException {
|
||||
return scalarQuantizer.recalculateCorrectiveOffset(
|
||||
in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction);
|
||||
in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1176,35 +1182,24 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return in.vectorValue();
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
return in.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return in.docID();
|
||||
public int ordToDoc(int ord) {
|
||||
return in.ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return in.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return in.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return in.iterator();
|
||||
}
|
||||
}
|
||||
|
||||
static final class NormalizedFloatVectorValues extends FloatVectorValues {
|
||||
private final FloatVectorValues values;
|
||||
private final float[] normalizedVector;
|
||||
int curDoc = -1;
|
||||
|
||||
public NormalizedFloatVectorValues(FloatVectorValues values) {
|
||||
this.values = values;
|
||||
|
@ -1222,38 +1217,25 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
public int ordToDoc(int ord) {
|
||||
return values.ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length);
|
||||
VectorUtil.l2normalize(normalizedVector);
|
||||
return normalizedVector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
public DocIndexIterator iterator() {
|
||||
return values.iterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return values.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
curDoc = values.nextDoc();
|
||||
if (curDoc != NO_MORE_DOCS) {
|
||||
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||
VectorUtil.l2normalize(normalizedVector);
|
||||
}
|
||||
return curDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
curDoc = values.advance(target);
|
||||
if (curDoc != NO_MORE_DOCS) {
|
||||
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||
VectorUtil.l2normalize(normalizedVector);
|
||||
}
|
||||
return curDoc;
|
||||
public NormalizedFloatVectorValues copy() throws IOException {
|
||||
return new NormalizedFloatVectorValues(values.copy());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,15 +30,13 @@ import org.apache.lucene.util.Bits;
|
|||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
|
||||
/**
|
||||
* Read the quantized vector values and their score correction values from the index input. This
|
||||
* supports both iterated and random access.
|
||||
*/
|
||||
public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues
|
||||
implements RandomAccessQuantizedByteVectorValues {
|
||||
public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -141,11 +139,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getScoreCorrectionConstant() {
|
||||
return scoreCorrectionConstant[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getScoreCorrectionConstant(int targetOrd) throws IOException {
|
||||
if (lastOrd == targetOrd) {
|
||||
|
@ -213,8 +206,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
*/
|
||||
public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(
|
||||
int dimension,
|
||||
int size,
|
||||
|
@ -226,30 +217,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(
|
||||
|
@ -270,20 +237,26 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
DenseOffHeapVectorValues copy = copy();
|
||||
DocIndexIterator iterator = copy.iterator();
|
||||
RandomVectorScorer vectorScorer =
|
||||
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return vectorScorer.score(copy.doc);
|
||||
return vectorScorer.score(iterator.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
|
@ -312,24 +285,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
public DocIndexIterator iterator() {
|
||||
return IndexedDISI.asDocIndexIterator(disi);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -372,17 +329,18 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
SparseOffHeapVectorValues copy = copy();
|
||||
DocIndexIterator iterator = copy.iterator();
|
||||
RandomVectorScorer vectorScorer =
|
||||
vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target);
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return vectorScorer.score(copy.disi.index());
|
||||
return vectorScorer.score(iterator.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return copy;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -404,8 +362,6 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
null);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
|
@ -417,23 +373,8 @@ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVect
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
||||
/**
|
||||
|
@ -27,34 +27,21 @@ import org.apache.lucene.search.VectorScorer;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class ByteVectorValues extends DocIdSetIterator {
|
||||
public abstract class ByteVectorValues extends KnnVectorValues {
|
||||
|
||||
/** Sole constructor */
|
||||
protected ByteVectorValues() {}
|
||||
|
||||
/** Return the dimension of the vectors */
|
||||
public abstract int dimension();
|
||||
|
||||
/**
|
||||
* Return the number of vectors for this field.
|
||||
*
|
||||
* @return the number of vectors returned by this iterator
|
||||
*/
|
||||
public abstract int size();
|
||||
|
||||
@Override
|
||||
public final long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the vector value for the current document ID. It is illegal to call this method when the
|
||||
* iterator is not positioned: before advancing, or after failing to advance. The returned array
|
||||
* may be shared across calls, re-used, and modified as the iterator advances.
|
||||
* Return the vector value for the given vector ordinal which must be in [0, size() - 1],
|
||||
* otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls.
|
||||
*
|
||||
* @return the vector value
|
||||
*/
|
||||
public abstract byte[] vectorValue() throws IOException;
|
||||
public abstract byte[] vectorValue(int ord) throws IOException;
|
||||
|
||||
@Override
|
||||
public abstract ByteVectorValues copy() throws IOException;
|
||||
|
||||
/**
|
||||
* Checks the Vector Encoding of a field
|
||||
|
@ -78,12 +65,53 @@ public abstract class ByteVectorValues extends DocIdSetIterator {
|
|||
}
|
||||
|
||||
/**
|
||||
* Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not
|
||||
* the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and
|
||||
* iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}.
|
||||
* Return a {@link VectorScorer} for the given query vector.
|
||||
*
|
||||
* @param query the query vector
|
||||
* @return a {@link VectorScorer} instance or null
|
||||
*/
|
||||
public abstract VectorScorer scorer(byte[] query) throws IOException;
|
||||
public VectorScorer scorer(byte[] query) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorEncoding getEncoding() {
|
||||
return VectorEncoding.BYTE;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link ByteVectorValues} from a list of byte arrays.
|
||||
*
|
||||
* @param vectors the list of byte arrays
|
||||
* @param dim the dimension of the vectors
|
||||
* @return a {@link ByteVectorValues} instance
|
||||
*/
|
||||
public static ByteVectorValues fromBytes(List<byte[]> vectors, int dim) {
|
||||
return new ByteVectorValues() {
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2760,16 +2760,16 @@ public final class CheckIndex implements Closeable {
|
|||
CheckIndex.Status.VectorValuesStatus status,
|
||||
CodecReader codecReader)
|
||||
throws IOException {
|
||||
int docCount = 0;
|
||||
int count = 0;
|
||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||
while (count < values.size()) {
|
||||
// search a sample of the vectors (those whose docid is a multiple of everyNdoc) to exercise the graph
|
||||
if (values.docID() % everyNdoc == 0) {
|
||||
if (values.ordToDoc(count) % everyNdoc == 0) {
|
||||
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
|
||||
if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) {
|
||||
codecReader
|
||||
.getVectorReader()
|
||||
.search(fieldInfo.name, values.vectorValue(), collector, null);
|
||||
.search(fieldInfo.name, values.vectorValue(count), collector, null);
|
||||
TopDocs docs = collector.topDocs();
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
|
@ -2777,7 +2777,7 @@ public final class CheckIndex implements Closeable {
|
|||
}
|
||||
}
|
||||
}
|
||||
int valueLength = values.vectorValue().length;
|
||||
int valueLength = values.vectorValue(count).length;
|
||||
if (valueLength != fieldInfo.getVectorDimension()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
|
@ -2787,19 +2787,19 @@ public final class CheckIndex implements Closeable {
|
|||
+ " not matching the field's dimension="
|
||||
+ fieldInfo.getVectorDimension());
|
||||
}
|
||||
++docCount;
|
||||
++count;
|
||||
}
|
||||
if (docCount != values.size()) {
|
||||
if (count != values.size()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has size="
|
||||
+ values.size()
|
||||
+ " but when iterated, returns "
|
||||
+ docCount
|
||||
+ count
|
||||
+ " docs with values");
|
||||
}
|
||||
status.totalVectorValues += docCount;
|
||||
status.totalVectorValues += count;
|
||||
}
|
||||
|
||||
private static void checkByteVectorValues(
|
||||
|
@ -2808,21 +2808,23 @@ public final class CheckIndex implements Closeable {
|
|||
CheckIndex.Status.VectorValuesStatus status,
|
||||
CodecReader codecReader)
|
||||
throws IOException {
|
||||
int docCount = 0;
|
||||
int count = 0;
|
||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||
boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name);
|
||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||
while (count < values.size()) {
|
||||
// search a sample of the vectors (those whose docid is a multiple of everyNdoc) to exercise the graph
|
||||
if (supportsSearch && values.docID() % everyNdoc == 0) {
|
||||
if (supportsSearch && values.ordToDoc(count) % everyNdoc == 0) {
|
||||
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
|
||||
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
|
||||
codecReader
|
||||
.getVectorReader()
|
||||
.search(fieldInfo.name, values.vectorValue(count), collector, null);
|
||||
TopDocs docs = collector.topDocs();
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||
}
|
||||
}
|
||||
int valueLength = values.vectorValue().length;
|
||||
int valueLength = values.vectorValue(count).length;
|
||||
if (valueLength != fieldInfo.getVectorDimension()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
|
@ -2832,19 +2834,19 @@ public final class CheckIndex implements Closeable {
|
|||
+ " not matching the field's dimension="
|
||||
+ fieldInfo.getVectorDimension());
|
||||
}
|
||||
++docCount;
|
||||
++count;
|
||||
}
|
||||
if (docCount != values.size()) {
|
||||
if (count != values.size()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has size="
|
||||
+ values.size()
|
||||
+ " but when iterated, returns "
|
||||
+ docCount
|
||||
+ count
|
||||
+ " docs with values");
|
||||
}
|
||||
status.totalVectorValues += docCount;
|
||||
status.totalVectorValues += count;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -429,37 +429,10 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
}
|
||||
|
||||
private class ExitableFloatVectorValues extends FloatVectorValues {
|
||||
private int docToCheck;
|
||||
private final FloatVectorValues vectorValues;
|
||||
|
||||
public ExitableFloatVectorValues(FloatVectorValues vectorValues) {
|
||||
this.vectorValues = vectorValues;
|
||||
docToCheck = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
final int advance = vectorValues.advance(target);
|
||||
if (advance >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return advance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return vectorValues.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
final int nextDoc = vectorValues.nextDoc();
|
||||
if (nextDoc >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return nextDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -468,8 +441,13 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectorValues.vectorValue();
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
return vectorValues.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return vectorValues.ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -477,61 +455,27 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return vectorValues.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createExitableIterator(vectorValues.iterator(), queryTimeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
|
||||
* if {@link Thread#interrupted()} returns true.
|
||||
*/
|
||||
private void checkAndThrow() {
|
||||
if (queryTimeout.shouldExit()) {
|
||||
throw new ExitingReaderException(
|
||||
"The request took too long to iterate over vector values. Timeout: "
|
||||
+ queryTimeout.toString()
|
||||
+ ", FloatVectorValues="
|
||||
+ in);
|
||||
} else if (Thread.interrupted()) {
|
||||
throw new ExitingReaderException(
|
||||
"Interrupted while iterating over vector values. FloatVectorValues=" + in);
|
||||
}
|
||||
@Override
|
||||
public FloatVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
private class ExitableByteVectorValues extends ByteVectorValues {
|
||||
private int docToCheck;
|
||||
private final ByteVectorValues vectorValues;
|
||||
|
||||
public ExitableByteVectorValues(ByteVectorValues vectorValues) {
|
||||
this.vectorValues = vectorValues;
|
||||
docToCheck = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
final int advance = vectorValues.advance(target);
|
||||
if (advance >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return advance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return vectorValues.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
final int nextDoc = vectorValues.nextDoc();
|
||||
if (nextDoc >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return nextDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -545,8 +489,18 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValues.vectorValue();
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
return vectorValues.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return vectorValues.ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createExitableIterator(vectorValues.iterator(), queryTimeout);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -554,23 +508,66 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return vectorValues.scorer(target);
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
|
||||
* if {@link Thread#interrupted()} returns true.
|
||||
*/
|
||||
@Override
|
||||
public ByteVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static KnnVectorValues.DocIndexIterator createExitableIterator(
|
||||
KnnVectorValues.DocIndexIterator delegate, QueryTimeout queryTimeout) {
|
||||
return new KnnVectorValues.DocIndexIterator() {
|
||||
private int nextCheck;
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return delegate.index();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return delegate.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int doc = delegate.nextDoc();
|
||||
if (doc >= nextCheck) {
|
||||
checkAndThrow();
|
||||
nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return delegate.cost();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
int doc = delegate.advance(target);
|
||||
if (doc >= nextCheck) {
|
||||
checkAndThrow();
|
||||
nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
private void checkAndThrow() {
|
||||
if (queryTimeout.shouldExit()) {
|
||||
throw new ExitingReaderException(
|
||||
"The request took too long to iterate over vector values. Timeout: "
|
||||
"The request took too long to iterate over knn vector values. Timeout: "
|
||||
+ queryTimeout.toString()
|
||||
+ ", ByteVectorValues="
|
||||
+ in);
|
||||
+ ", KnnVectorValues="
|
||||
+ delegate);
|
||||
} else if (Thread.interrupted()) {
|
||||
throw new ExitingReaderException(
|
||||
"Interrupted while iterating over vector values. ByteVectorValues=" + in);
|
||||
}
|
||||
"Interrupted while iterating over knn vector values. KnnVectorValues=" + delegate);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/** Wrapper class for another PointValues implementation that is used by ExitableFields. */
|
||||
|
@ -683,7 +680,7 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
if (queryTimeout.shouldExit()) {
|
||||
throw new ExitingReaderException(
|
||||
"The request took too long to intersect point values. Timeout: "
|
||||
+ queryTimeout.toString()
|
||||
+ queryTimeout
|
||||
+ ", PointValues="
|
||||
+ pointValues);
|
||||
} else if (Thread.interrupted()) {
|
||||
|
@ -815,7 +812,7 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
/** Wrapper class for another Terms implementation that is used by ExitableFields. */
|
||||
public static class ExitableTerms extends FilterTerms {
|
||||
|
||||
private QueryTimeout queryTimeout;
|
||||
private final QueryTimeout queryTimeout;
|
||||
|
||||
/** Constructor * */
|
||||
public ExitableTerms(Terms terms, QueryTimeout queryTimeout) {
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
|
||||
/**
|
||||
|
@ -27,34 +27,21 @@ import org.apache.lucene.search.VectorScorer;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class FloatVectorValues extends DocIdSetIterator {
|
||||
public abstract class FloatVectorValues extends KnnVectorValues {
|
||||
|
||||
/** Sole constructor */
|
||||
protected FloatVectorValues() {}
|
||||
|
||||
/** Return the dimension of the vectors */
|
||||
public abstract int dimension();
|
||||
|
||||
/**
|
||||
* Return the number of vectors for this field.
|
||||
*
|
||||
* @return the number of vectors returned by this iterator
|
||||
*/
|
||||
public abstract int size();
|
||||
|
||||
@Override
|
||||
public final long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the vector value for the current document ID. It is illegal to call this method when the
|
||||
* iterator is not positioned: before advancing, or after failing to advance. The returned array
|
||||
* may be shared across calls, re-used, and modified as the iterator advances.
|
||||
* Return the vector value for the given vector ordinal which must be in [0, size() - 1],
|
||||
* otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls.
|
||||
*
|
||||
* @return the vector value
|
||||
*/
|
||||
public abstract float[] vectorValue() throws IOException;
|
||||
public abstract float[] vectorValue(int ord) throws IOException;
|
||||
|
||||
@Override
|
||||
public abstract FloatVectorValues copy() throws IOException;
|
||||
|
||||
/**
|
||||
* Checks the Vector Encoding of a field
|
||||
|
@ -79,12 +66,53 @@ public abstract class FloatVectorValues extends DocIdSetIterator {
|
|||
|
||||
/**
|
||||
* Return a {@link VectorScorer} for the given query vector and the current {@link
|
||||
* FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for
|
||||
* this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the
|
||||
* iteration of this {@link FloatVectorValues}.
|
||||
* FloatVectorValues}.
|
||||
*
|
||||
* @param query the query vector
|
||||
* @param target the query vector
|
||||
* @return a {@link VectorScorer} instance or null
|
||||
*/
|
||||
public abstract VectorScorer scorer(float[] query) throws IOException;
|
||||
public VectorScorer scorer(float[] target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorEncoding getEncoding() {
|
||||
return VectorEncoding.FLOAT32;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link FloatVectorValues} from a list of float arrays.
|
||||
*
|
||||
* @param vectors the list of float arrays
|
||||
* @param dim the dimension of the vectors
|
||||
* @return a {@link FloatVectorValues} instance
|
||||
*/
|
||||
public static FloatVectorValues fromFloats(List<float[]> vectors, int dim) {
|
||||
return new FloatVectorValues() {
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,229 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
* This class abstracts addressing of document vector values indexed as {@link KnnFloatVectorField}
|
||||
* or {@link KnnByteVectorField}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class KnnVectorValues {
|
||||
|
||||
/** Return the dimension of the vectors */
|
||||
public abstract int dimension();
|
||||
|
||||
/**
|
||||
* Return the number of vectors for this field.
|
||||
*
|
||||
* @return the number of vectors returned by this iterator
|
||||
*/
|
||||
public abstract int size();
|
||||
|
||||
/**
|
||||
* Return the docid of the document indexed with the given vector ordinal. This default
|
||||
* implementation returns the argument and is appropriate for dense values implementations where
|
||||
* every doc has a single value.
|
||||
*/
|
||||
public int ordToDoc(int ord) {
|
||||
return ord;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access
|
||||
* different values at once, to avoid overwriting the underlying vector returned.
|
||||
*/
|
||||
public abstract KnnVectorValues copy() throws IOException;
|
||||
|
||||
/** Returns the vector byte length, defaults to dimension multiplied by float byte size */
|
||||
public int getVectorByteLength() {
|
||||
return dimension() * getEncoding().byteSize;
|
||||
}
|
||||
|
||||
/** The vector encoding of these values. */
|
||||
public abstract VectorEncoding getEncoding();
|
||||
|
||||
/** Returns a Bits accepting docs accepted by the argument and having a vector value */
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
// FIXME: change default to return acceptDocs and provide this impl
|
||||
// somewhere more specialized (in every non-dense impl).
|
||||
if (acceptDocs == null) {
|
||||
return null;
|
||||
}
|
||||
return new Bits() {
|
||||
@Override
|
||||
public boolean get(int index) {
|
||||
return acceptDocs.get(ordToDoc(index));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return size();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/** Create an iterator for this instance. */
|
||||
public DocIndexIterator iterator() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
/**
|
||||
* A DocIdSetIterator that also provides an index() method tracking a distinct ordinal for a
|
||||
* vector associated with each doc.
|
||||
*/
|
||||
public abstract static class DocIndexIterator extends DocIdSetIterator {
|
||||
|
||||
/** return the value index (aka "ordinal" or "ord") corresponding to the current doc */
|
||||
public abstract int index();
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an iterator for instances where every doc has a value, and the value ordinals are equal
|
||||
* to the docids.
|
||||
*/
|
||||
protected DocIndexIterator createDenseIterator() {
|
||||
return new DocIndexIterator() {
|
||||
|
||||
int doc = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
if (doc >= size() - 1) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
} else {
|
||||
return ++doc;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
if (target >= size()) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return size();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an iterator from a DocIdSetIterator indicating which docs have values, and for which
|
||||
* ordinals increase monotonically with docid.
|
||||
*/
|
||||
protected static DocIndexIterator fromDISI(DocIdSetIterator docsWithField) {
|
||||
return new DocIndexIterator() {
|
||||
|
||||
int ord = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docsWithField.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
if (docID() == NO_MORE_DOCS) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
ord++;
|
||||
return docsWithField.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return docsWithField.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return docsWithField.cost();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic
|
||||
* (docid increases when ordinal does).
|
||||
*/
|
||||
protected DocIndexIterator createSparseIterator() {
|
||||
return new DocIndexIterator() {
|
||||
private int ord = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (ord == -1) {
|
||||
return -1;
|
||||
}
|
||||
if (ord == NO_MORE_DOCS) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return ordToDoc(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
if (ord >= size() - 1) {
|
||||
ord = NO_MORE_DOCS;
|
||||
} else {
|
||||
++ord;
|
||||
}
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return slowAdvance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return size();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -34,9 +34,7 @@ import org.apache.lucene.codecs.StoredFieldsReader;
|
|||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues;
|
||||
import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
|
@ -303,38 +301,21 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
|
|||
}
|
||||
}
|
||||
|
||||
private record DocValuesSub<T extends DocIdSetIterator>(T sub, int docStart, int docEnd) {}
|
||||
private record DocValuesSub<T extends KnnVectorValues>(T sub, int docStart, int ordStart) {}
|
||||
|
||||
private static class MergedDocIdSetIterator<T extends DocIdSetIterator> extends DocIdSetIterator {
|
||||
private static class MergedDocIterator<T extends KnnVectorValues>
|
||||
extends KnnVectorValues.DocIndexIterator {
|
||||
|
||||
final Iterator<DocValuesSub<T>> it;
|
||||
final long cost;
|
||||
DocValuesSub<T> current;
|
||||
int currentIndex = 0;
|
||||
KnnVectorValues.DocIndexIterator currentIterator;
|
||||
int ord = -1;
|
||||
int doc = -1;
|
||||
|
||||
MergedDocIdSetIterator(List<DocValuesSub<T>> subs) {
|
||||
long cost = 0;
|
||||
for (DocValuesSub<T> sub : subs) {
|
||||
if (sub.sub != null) {
|
||||
cost += sub.sub.cost();
|
||||
}
|
||||
}
|
||||
this.cost = cost;
|
||||
MergedDocIterator(List<DocValuesSub<T>> subs) {
|
||||
this.it = subs.iterator();
|
||||
current = it.next();
|
||||
}
|
||||
|
||||
private boolean advanceSub(int target) {
|
||||
while (current.sub == null || current.docEnd <= target) {
|
||||
if (it.hasNext() == false) {
|
||||
doc = NO_MORE_DOCS;
|
||||
return false;
|
||||
}
|
||||
current = it.next();
|
||||
currentIndex++;
|
||||
}
|
||||
return true;
|
||||
currentIterator = currentIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -342,41 +323,47 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
|
|||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (true) {
|
||||
if (current.sub != null) {
|
||||
int next = current.sub.nextDoc();
|
||||
int next = currentIterator.nextDoc();
|
||||
if (next != NO_MORE_DOCS) {
|
||||
++ord;
|
||||
return doc = current.docStart + next;
|
||||
}
|
||||
}
|
||||
if (it.hasNext() == false) {
|
||||
ord = NO_MORE_DOCS;
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
current = it.next();
|
||||
currentIndex++;
|
||||
currentIterator = currentIterator();
|
||||
ord = current.ordStart - 1;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
while (true) {
|
||||
if (advanceSub(target) == false) {
|
||||
return DocIdSetIterator.NO_MORE_DOCS;
|
||||
}
|
||||
int next = current.sub.advance(target - current.docStart);
|
||||
if (next == DocIdSetIterator.NO_MORE_DOCS) {
|
||||
target = current.docEnd;
|
||||
private KnnVectorValues.DocIndexIterator currentIterator() {
|
||||
if (current.sub != null) {
|
||||
return current.sub.iterator();
|
||||
} else {
|
||||
return doc = current.docStart + next;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return cost;
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -848,55 +835,75 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
|
|||
int size = 0;
|
||||
for (CodecReader reader : codecReaders) {
|
||||
FloatVectorValues values = reader.getFloatVectorValues(field);
|
||||
subs.add(new DocValuesSub<>(values, docStarts[i], size));
|
||||
if (values != null) {
|
||||
if (dimension == -1) {
|
||||
dimension = values.dimension();
|
||||
}
|
||||
size += values.size();
|
||||
}
|
||||
subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1]));
|
||||
i++;
|
||||
}
|
||||
final int finalDimension = dimension;
|
||||
final int finalSize = size;
|
||||
MergedDocIdSetIterator<FloatVectorValues> mergedIterator = new MergedDocIdSetIterator<>(subs);
|
||||
return new FloatVectorValues() {
|
||||
return new MergedFloatVectorValues(dimension, size, subs);
|
||||
}
|
||||
|
||||
class MergedFloatVectorValues extends FloatVectorValues {
|
||||
final int dimension;
|
||||
final int size;
|
||||
final DocValuesSub<?>[] subs;
|
||||
final MergedDocIterator<FloatVectorValues> iter;
|
||||
final int[] starts;
|
||||
int lastSubIndex;
|
||||
|
||||
MergedFloatVectorValues(int dimension, int size, List<DocValuesSub<FloatVectorValues>> subs) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.subs = subs.toArray(new DocValuesSub<?>[0]);
|
||||
iter = new MergedDocIterator<>(subs);
|
||||
// [0, start(1), ..., size] - we want the extra element
|
||||
// to avoid checking for out-of-array bounds
|
||||
starts = new int[subs.size() + 1];
|
||||
for (int i = 0; i < subs.size(); i++) {
|
||||
starts[i] = subs.get(i).ordStart;
|
||||
}
|
||||
starts[starts.length - 1] = size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MergedDocIterator<FloatVectorValues> iterator() {
|
||||
return iter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return finalDimension;
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return finalSize;
|
||||
return size;
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public FloatVectorValues copy() throws IOException {
|
||||
List<DocValuesSub<FloatVectorValues>> subsCopy = new ArrayList<>();
|
||||
for (Object sub : subs) {
|
||||
subsCopy.add((DocValuesSub<FloatVectorValues>) sub);
|
||||
}
|
||||
return new MergedFloatVectorValues(dimension, size, subsCopy);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return mergedIterator.current.sub.vectorValue();
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
assert ord >= 0 && ord < size;
|
||||
// We need to implement fully random-access API here in order to support callers like
|
||||
// SortingCodecReader that
|
||||
// rely on it.
|
||||
lastSubIndex = findSub(ord, lastSubIndex, starts);
|
||||
return ((FloatVectorValues) subs[lastSubIndex].sub)
|
||||
.vectorValue(ord - subs[lastSubIndex].ordStart);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return mergedIterator.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return mergedIterator.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return mergedIterator.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -907,55 +914,96 @@ final class SlowCompositeCodecReaderWrapper extends CodecReader {
|
|||
int size = 0;
|
||||
for (CodecReader reader : codecReaders) {
|
||||
ByteVectorValues values = reader.getByteVectorValues(field);
|
||||
subs.add(new DocValuesSub<>(values, docStarts[i], size));
|
||||
if (values != null) {
|
||||
if (dimension == -1) {
|
||||
dimension = values.dimension();
|
||||
}
|
||||
size += values.size();
|
||||
}
|
||||
subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1]));
|
||||
i++;
|
||||
}
|
||||
final int finalDimension = dimension;
|
||||
final int finalSize = size;
|
||||
MergedDocIdSetIterator<ByteVectorValues> mergedIterator = new MergedDocIdSetIterator<>(subs);
|
||||
return new ByteVectorValues() {
|
||||
return new MergedByteVectorValues(dimension, size, subs);
|
||||
}
|
||||
|
||||
class MergedByteVectorValues extends ByteVectorValues {
|
||||
final int dimension;
|
||||
final int size;
|
||||
final DocValuesSub<?>[] subs;
|
||||
final MergedDocIterator<ByteVectorValues> iter;
|
||||
final int[] starts;
|
||||
int lastSubIndex;
|
||||
|
||||
MergedByteVectorValues(int dimension, int size, List<DocValuesSub<ByteVectorValues>> subs) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.subs = subs.toArray(new DocValuesSub<?>[0]);
|
||||
iter = new MergedDocIterator<>(subs);
|
||||
// [0, start(1), ..., size] - we want the extra element
|
||||
// to avoid checking for out-of-array bounds
|
||||
starts = new int[subs.size() + 1];
|
||||
for (int i = 0; i < subs.size(); i++) {
|
||||
starts[i] = subs.get(i).ordStart;
|
||||
}
|
||||
starts[starts.length - 1] = size;
|
||||
}
|
||||
|
||||
/**
 * Returns the {@code iter} field, which is final and assigned once in the constructor. The same
 * iterator instance is therefore returned on every call.
 */
@Override
|
||||
public MergedDocIterator<ByteVectorValues> iterator() {
|
||||
return iter;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return finalDimension;
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return finalSize;
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return mergedIterator.current.sub.vectorValue();
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
assert ord >= 0 && ord < size;
|
||||
// We need to implement fully random-access API here in order to support callers like
|
||||
// SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some
|
||||
// repetition.
|
||||
lastSubIndex = findSub(ord, lastSubIndex, starts);
|
||||
return ((ByteVectorValues) subs[lastSubIndex].sub)
|
||||
.vectorValue(ord - subs[lastSubIndex].ordStart);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
@Override
|
||||
public int docID() {
|
||||
return mergedIterator.docID();
|
||||
public ByteVectorValues copy() throws IOException {
|
||||
List<DocValuesSub<ByteVectorValues>> newSubs = new ArrayList<>();
|
||||
for (Object sub : subs) {
|
||||
newSubs.add((DocValuesSub<ByteVectorValues>) sub);
|
||||
}
|
||||
return new MergedByteVectorValues(dimension, size, newSubs);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return mergedIterator.nextDoc();
|
||||
private static int findSub(int ord, int lastSubIndex, int[] starts) {
|
||||
if (ord >= starts[lastSubIndex]) {
|
||||
if (ord >= starts[lastSubIndex + 1]) {
|
||||
return binarySearchStarts(starts, ord, lastSubIndex + 1, starts.length);
|
||||
}
|
||||
} else {
|
||||
return binarySearchStarts(starts, ord, 0, lastSubIndex);
|
||||
}
|
||||
return lastSubIndex;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return mergedIterator.advance(target);
|
||||
private static int binarySearchStarts(int[] starts, int ord, int from, int to) {
|
||||
int pos = Arrays.binarySearch(starts, from, to, ord);
|
||||
if (pos < 0) {
|
||||
// subtract one since binarySearch returns an *insertion point*
|
||||
return -2 - pos;
|
||||
} else {
|
||||
return pos;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
|||
import java.util.Iterator;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.function.Supplier;
|
||||
import org.apache.lucene.codecs.DocValuesProducer;
|
||||
import org.apache.lucene.codecs.FieldsProducer;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
|
@ -32,10 +33,11 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.apache.lucene.util.IOSupplier;
|
||||
|
@ -206,121 +208,175 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
}
|
||||
}
|
||||
|
||||
/** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */
|
||||
private static class SortingFloatVectorValues extends FloatVectorValues {
|
||||
final int size;
|
||||
final int dimension;
|
||||
final FixedBitSet docsWithField;
|
||||
final float[][] vectors;
|
||||
/**
|
||||
* Factory for SortingValuesIterator. This enables us to create new iterators as needed without
|
||||
* recomputing the sorting mappings.
|
||||
*/
|
||||
static class SortingIteratorSupplier implements Supplier<SortingValuesIterator> {
|
||||
private final FixedBitSet docBits;
|
||||
private final int[] docToOrd;
|
||||
private final int size;
|
||||
|
||||
private int docId = -1;
|
||||
|
||||
SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.size = delegate.size();
|
||||
this.dimension = delegate.dimension();
|
||||
docsWithField = new FixedBitSet(sortMap.size());
|
||||
vectors = new float[sortMap.size()][];
|
||||
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(doc);
|
||||
docsWithField.set(newDocID);
|
||||
vectors[newDocID] = delegate.vectorValue().clone();
|
||||
SortingIteratorSupplier(FixedBitSet docBits, int[] docToOrd, int size) {
|
||||
this.docBits = docBits;
|
||||
this.docToOrd = docToOrd;
|
||||
this.size = size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SortingValuesIterator get() {
|
||||
return new SortingValuesIterator(docBits, docToOrd, size);
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a factory for SortingValuesIterator. Does the work of computing the (new docId to old
|
||||
* ordinal) mapping, and caches the result, enabling it to create new iterators cheaply.
|
||||
*
|
||||
* @param values the values over which to iterate
|
||||
* @param docMap the mapping from "old" docIds to "new" (sorted) docIds.
|
||||
*/
|
||||
public static SortingIteratorSupplier iteratorSupplier(
|
||||
KnnVectorValues values, Sorter.DocMap docMap) throws IOException {
|
||||
|
||||
final int[] docToOrd = new int[docMap.size()];
|
||||
final FixedBitSet docBits = new FixedBitSet(docMap.size());
|
||||
int count = 0;
|
||||
// Note: docToOrd will contain zero for docids that have no vector. This is OK though
|
||||
// because the iterator cannot be positioned on such docs
|
||||
KnnVectorValues.DocIndexIterator iter = values.iterator();
|
||||
for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) {
|
||||
int newDocId = docMap.oldToNew(doc);
|
||||
if (newDocId != -1) {
|
||||
docToOrd[newDocId] = iter.index();
|
||||
docBits.set(newDocId);
|
||||
++count;
|
||||
}
|
||||
}
|
||||
return new SortingIteratorSupplier(docBits, docToOrd, count);
|
||||
}
|
||||
|
||||
/**
|
||||
* Iterator over KnnVectorValues accepting a mapping to differently-sorted docs. Consequently
|
||||
* index() may skip around, not increasing monotonically as iteration proceeds.
|
||||
*/
|
||||
public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator {
|
||||
private final FixedBitSet docBits;
|
||||
private final DocIdSetIterator docsWithValues;
|
||||
private final int[] docToOrd;
|
||||
|
||||
int doc = -1;
|
||||
|
||||
SortingValuesIterator(FixedBitSet docBits, int[] docToOrd, int size) {
|
||||
this.docBits = docBits;
|
||||
this.docToOrd = docToOrd;
|
||||
docsWithValues = new BitSetIterator(docBits, size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
assert docBits.get(doc);
|
||||
return docToOrd[doc];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(docId + 1);
|
||||
if (doc != NO_MORE_DOCS) {
|
||||
doc = docsWithValues.nextDoc();
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return vectors[docId];
|
||||
public long cost() {
|
||||
return docBits.cardinality();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
if (target >= docsWithField.length()) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return docId = docsWithField.nextSetBit(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||
final int size;
|
||||
final int dimension;
|
||||
final FixedBitSet docsWithField;
|
||||
final byte[][] vectors;
|
||||
/** Sorting FloatVectorValues that maps ordinals using the provided sortMap */
|
||||
private static class SortingFloatVectorValues extends FloatVectorValues {
|
||||
final FloatVectorValues delegate;
|
||||
final SortingIteratorSupplier iteratorSupplier;
|
||||
|
||||
private int docId = -1;
|
||||
|
||||
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.size = delegate.size();
|
||||
this.dimension = delegate.dimension();
|
||||
docsWithField = new FixedBitSet(sortMap.size());
|
||||
vectors = new byte[sortMap.size()][];
|
||||
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(doc);
|
||||
docsWithField.set(newDocID);
|
||||
vectors[newDocID] = delegate.vectorValue().clone();
|
||||
}
|
||||
SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.delegate = delegate;
|
||||
// SortingValuesIterator consumes the iterator and records the docs and ord mapping
|
||||
iteratorSupplier = iteratorSupplier(delegate, sortMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(docId + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectors[docId];
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
// ords are interpreted in the delegate's ord-space.
|
||||
return delegate.vectorValue(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
return delegate.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
return iteratorSupplier.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
if (target >= docsWithField.length()) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return docId = docsWithField.nextSetBit(target);
|
||||
public FloatVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] target) {
|
||||
public DocIndexIterator iterator() {
|
||||
return iteratorSupplier.get();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Sorting ByteVectorValues that exposes the delegate's vectors unchanged by ordinal, while its
 * iterators visit documents according to the provided sortMap.
 */
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||
final ByteVectorValues delegate;
|
||||
final SortingIteratorSupplier iteratorSupplier;
|
||||
|
||||
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.delegate = delegate;
|
||||
// iteratorSupplier(...) walks the delegate's iterator once and records, per new docId, that a
// vector exists and which delegate ordinal it has
|
||||
iteratorSupplier = iteratorSupplier(delegate, sortMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue(int ord) throws IOException {
|
||||
// ords are interpreted in the delegate's ord-space.
return delegate.vectorValue(ord);
|
||||
}
|
||||
|
||||
// Each call asks the supplier for an iterator; the supplier's get() constructs a new
// SortingValuesIterator over the recorded mapping.
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return iteratorSupplier.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegate.dimension();
|
||||
}
|
||||
|
||||
// NOTE(review): this is the number of docs the supplier recorded (those not mapped to -1), not
// delegate.size(). Since vectorValue(ord) stays in the delegate's ord-space, valid ordinals are
// presumably not guaranteed to be < size() when docs were dropped; confirm with callers.
@Override
|
||||
public int size() {
|
||||
return iteratorSupplier.size();
|
||||
}
|
||||
|
||||
// Copying is not supported by this wrapper.
@Override
|
||||
public ByteVectorValues copy() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -181,8 +181,8 @@ public class FieldExistsQuery extends Query {
|
|||
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
||||
iterator =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32 -> context.reader().getFloatVectorValues(field);
|
||||
case BYTE -> context.reader().getByteVectorValues(field);
|
||||
case FLOAT32 -> context.reader().getFloatVectorValues(field).iterator();
|
||||
case BYTE -> context.reader().getByteVectorValues(field).iterator();
|
||||
};
|
||||
} else if (fieldInfo.getDocValuesType()
|
||||
!= DocValuesType.NONE) { // the field indexes doc values
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.lucene.util.hnsw;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.HnswGraphProvider;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.search.TaskExecutor;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
@ -46,7 +46,7 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
|
|||
}
|
||||
|
||||
@Override
|
||||
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
|
||||
protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
|
||||
throws IOException {
|
||||
if (initReader == null) {
|
||||
return new HnswConcurrentMergeBuilder(
|
||||
|
@ -61,7 +61,7 @@ public class ConcurrentHnswMerger extends IncrementalHnswGraphMerger {
|
|||
|
||||
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
|
||||
BitSet initializedNodes = new FixedBitSet(maxOrd);
|
||||
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
|
||||
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
|
||||
|
||||
return new HnswConcurrentMergeBuilder(
|
||||
taskExecutor,
|
||||
|
|
|
@ -18,8 +18,8 @@ package org.apache.lucene.util.hnsw;
|
|||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
|
||||
|
@ -45,12 +45,12 @@ public interface HnswGraphMerger {
|
|||
/**
|
||||
* Merge and produce the on heap graph
|
||||
*
|
||||
* @param mergedVectorIterator iterator over the vectors in the merged segment
|
||||
* @param mergedVectorValues view of the vectors in the merged segment
|
||||
* @param infoStream optional info stream to set to builder
|
||||
* @param maxOrd max number of vectors that will be added to the graph
|
||||
* @return merged graph
|
||||
* @throws IOException during merge
|
||||
*/
|
||||
OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd)
|
||||
OnHeapHnswGraph merge(KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd)
|
||||
throws IOException;
|
||||
}
|
||||
|
|
|
@ -25,9 +25,9 @@ import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
|||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.internal.hppc.IntIntHashMap;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
@ -108,12 +108,12 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
|
|||
* Builds a new HnswGraphBuilder using the biggest graph from the merge state as a starting point.
|
||||
* If no valid readers were added to the merge state, a new graph is created.
|
||||
*
|
||||
* @param mergedVectorIterator iterator over the vectors in the merged segment
|
||||
* @param mergedVectorValues vector values in the merged segment
|
||||
* @param maxOrd max num of vectors that will be merged into the graph
|
||||
* @return HnswGraphBuilder
|
||||
* @throws IOException If an error occurs while reading from the merge state
|
||||
*/
|
||||
protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd)
|
||||
protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd)
|
||||
throws IOException {
|
||||
if (initReader == null) {
|
||||
return HnswGraphBuilder.create(
|
||||
|
@ -123,7 +123,7 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
|
|||
HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name);
|
||||
|
||||
BitSet initializedNodes = new FixedBitSet(maxOrd);
|
||||
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes);
|
||||
int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes);
|
||||
return InitializedHnswGraphBuilder.fromGraph(
|
||||
scorerSupplier,
|
||||
M,
|
||||
|
@ -137,8 +137,8 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
|
|||
|
||||
@Override
|
||||
public OnHeapHnswGraph merge(
|
||||
DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException {
|
||||
HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd);
|
||||
KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException {
|
||||
HnswBuilder builder = createBuilder(mergedVectorValues, maxOrd);
|
||||
builder.setInfoStream(infoStream);
|
||||
return builder.build(maxOrd);
|
||||
}
|
||||
|
@ -147,46 +147,45 @@ public class IncrementalHnswGraphMerger implements HnswGraphMerger {
|
|||
* Creates a new mapping from old ordinals to new ordinals and returns the total number of vectors
|
||||
* in the newly merged segment.
|
||||
*
|
||||
* @param mergedVectorIterator iterator over the vectors in the merged segment
|
||||
* @param mergedVectorValues vector values in the merged segment
|
||||
* @param initializedNodes track what nodes have been initialized
|
||||
* @return the mapping from old ordinals to new ordinals
|
||||
* @throws IOException If an error occurs while reading from the merge state
|
||||
*/
|
||||
protected final int[] getNewOrdMapping(
|
||||
DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException {
|
||||
DocIdSetIterator initializerIterator = null;
|
||||
KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException {
|
||||
KnnVectorValues.DocIndexIterator initializerIterator = null;
|
||||
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name);
|
||||
case FLOAT32 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name);
|
||||
case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator();
|
||||
case FLOAT32 ->
|
||||
initializerIterator = initReader.getFloatVectorValues(fieldInfo.name).iterator();
|
||||
}
|
||||
|
||||
IntIntHashMap newIdToOldOrdinal = new IntIntHashMap(initGraphSize);
|
||||
int oldOrd = 0;
|
||||
int maxNewDocID = -1;
|
||||
for (int oldId = initializerIterator.nextDoc();
|
||||
oldId != NO_MORE_DOCS;
|
||||
oldId = initializerIterator.nextDoc()) {
|
||||
int newId = initDocMap.get(oldId);
|
||||
for (int docId = initializerIterator.nextDoc();
|
||||
docId != NO_MORE_DOCS;
|
||||
docId = initializerIterator.nextDoc()) {
|
||||
int newId = initDocMap.get(docId);
|
||||
maxNewDocID = Math.max(newId, maxNewDocID);
|
||||
newIdToOldOrdinal.put(newId, oldOrd);
|
||||
oldOrd++;
|
||||
newIdToOldOrdinal.put(newId, initializerIterator.index());
|
||||
}
|
||||
|
||||
if (maxNewDocID == -1) {
|
||||
return new int[0];
|
||||
}
|
||||
final int[] oldToNewOrdinalMap = new int[initGraphSize];
|
||||
int newOrd = 0;
|
||||
KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator();
|
||||
for (int newDocId = mergedVectorIterator.nextDoc();
|
||||
newDocId <= maxNewDocID;
|
||||
newDocId = mergedVectorIterator.nextDoc()) {
|
||||
int hashDocIndex = newIdToOldOrdinal.indexOf(newDocId);
|
||||
if (newIdToOldOrdinal.indexExists(hashDocIndex)) {
|
||||
int newOrd = mergedVectorIterator.index();
|
||||
initializedNodes.set(newOrd);
|
||||
oldToNewOrdinalMap[newIdToOldOrdinal.indexGet(hashDocIndex)] = newOrd;
|
||||
}
|
||||
newOrd++;
|
||||
}
|
||||
return oldToNewOrdinalMap;
|
||||
}
|
||||
|
|
|
@ -1,175 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
|
||||
* implementations of KNN search.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public interface RandomAccessVectorValues {
|
||||
|
||||
/** Return the number of vector values */
|
||||
int size();
|
||||
|
||||
/** Return the dimension of the returned vector values */
|
||||
int dimension();
|
||||
|
||||
/**
|
||||
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
|
||||
* access different values at once, to avoid overwriting the underlying vector returned.
|
||||
*/
|
||||
RandomAccessVectorValues copy() throws IOException;
|
||||
|
||||
/**
|
||||
* Returns a slice of the underlying {@link IndexInput} that contains the vector values if
|
||||
* available
|
||||
*/
|
||||
default IndexInput getSlice() {
|
||||
return null;
|
||||
}
|
||||
|
||||
/** Returns the byte length of the vector values. */
|
||||
int getVectorByteLength();
|
||||
|
||||
/**
|
||||
* Translates vector ordinal to the correct document ID. By default, this is an identity function.
|
||||
*
|
||||
* @param ord the vector ordinal
|
||||
* @return the document Id for that vector ordinal
|
||||
*/
|
||||
default int ordToDoc(int ord) {
|
||||
return ord;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the {@link Bits} representing live documents. By default, this is an identity function.
|
||||
*
|
||||
* @param acceptDocs the accept docs
|
||||
* @return the accept docs
|
||||
*/
|
||||
default Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
|
||||
/** Float vector values. */
|
||||
interface Floats extends RandomAccessVectorValues {
|
||||
@Override
|
||||
RandomAccessVectorValues.Floats copy() throws IOException;
|
||||
|
||||
/**
|
||||
* Return the vector value indexed at the given ordinal.
|
||||
*
|
||||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
||||
*/
|
||||
float[] vectorValue(int targetOrd) throws IOException;
|
||||
|
||||
/** Returns the vector byte length, defaults to dimension multiplied by float byte size */
|
||||
@Override
|
||||
default int getVectorByteLength() {
|
||||
return dimension() * Float.BYTES;
|
||||
}
|
||||
}
|
||||
|
||||
/** Byte vector values. */
|
||||
interface Bytes extends RandomAccessVectorValues {
|
||||
@Override
|
||||
RandomAccessVectorValues.Bytes copy() throws IOException;
|
||||
|
||||
/**
|
||||
* Return the vector value indexed at the given ordinal.
|
||||
*
|
||||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
||||
*/
|
||||
byte[] vectorValue(int targetOrd) throws IOException;
|
||||
|
||||
/** Returns the vector byte length, defaults to dimension multiplied by byte size */
|
||||
@Override
|
||||
default int getVectorByteLength() {
|
||||
return dimension() * Byte.BYTES;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link RandomAccessVectorValues.Floats} from a list of float arrays.
|
||||
*
|
||||
* @param vectors the list of float arrays
|
||||
* @param dim the dimension of the vectors
|
||||
* @return a {@link RandomAccessVectorValues.Floats} instance
|
||||
*/
|
||||
static RandomAccessVectorValues.Floats fromFloats(List<float[]> vectors, int dim) {
|
||||
return new RandomAccessVectorValues.Floats() {
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues.Floats copy() {
|
||||
return this;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a {@link RandomAccessVectorValues.Bytes} from a list of byte arrays.
|
||||
*
|
||||
* @param vectors the list of byte arrays
|
||||
* @param dim the dimension of the vectors
|
||||
* @return a {@link RandomAccessVectorValues.Bytes} instance
|
||||
*/
|
||||
static RandomAccessVectorValues.Bytes fromBytes(List<byte[]> vectors, int dim) {
|
||||
return new RandomAccessVectorValues.Bytes() {
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue(int targetOrd) {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues.Bytes copy() {
|
||||
return this;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
|
||||
|
@ -57,14 +58,14 @@ public interface RandomVectorScorer {
|
|||
|
||||
/** Creates a default scorer for random access vectors. */
|
||||
abstract class AbstractRandomVectorScorer implements RandomVectorScorer {
|
||||
private final RandomAccessVectorValues values;
|
||||
private final KnnVectorValues values;
|
||||
|
||||
/**
|
||||
* Creates a new scorer for the given vector values.
|
||||
*
|
||||
* @param values the vector values
|
||||
*/
|
||||
public AbstractRandomVectorScorer(RandomAccessVectorValues values) {
|
||||
public AbstractRandomVectorScorer(KnnVectorValues values) {
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,10 @@
|
|||
package org.apache.lucene.util.quantization;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
||||
/**
|
||||
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
|
||||
|
@ -27,31 +28,31 @@ import org.apache.lucene.search.VectorScorer;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class QuantizedByteVectorValues extends DocIdSetIterator {
|
||||
public abstract float getScoreCorrectionConstant() throws IOException;
|
||||
public abstract class QuantizedByteVectorValues extends ByteVectorValues implements HasIndexSlice {
|
||||
|
||||
public abstract byte[] vectorValue() throws IOException;
|
||||
|
||||
/** Return the dimension of the vectors */
|
||||
public abstract int dimension();
|
||||
|
||||
/**
|
||||
* Return the number of vectors for this field.
|
||||
*
|
||||
* @return the number of vectors returned by this iterator
|
||||
*/
|
||||
public abstract int size();
|
||||
|
||||
@Override
|
||||
public final long cost() {
|
||||
return size();
|
||||
public ScalarQuantizer getScalarQuantizer() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
public abstract float getScoreCorrectionConstant(int ord) throws IOException;
|
||||
|
||||
/**
|
||||
* Return a {@link VectorScorer} for the given query vector.
|
||||
*
|
||||
* @param query the query vector
|
||||
* @return a {@link VectorScorer} instance or null
|
||||
*/
|
||||
public abstract VectorScorer scorer(float[] query) throws IOException;
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public QuantizedByteVectorValues copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public IndexInput getSlice() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.List;
|
|||
import java.util.Random;
|
||||
import java.util.stream.IntStream;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
@ -269,11 +270,12 @@ public class ScalarQuantizer {
|
|||
if (totalVectorCount == 0) {
|
||||
return new ScalarQuantizer(0f, 0f, bits);
|
||||
}
|
||||
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
|
||||
if (confidenceInterval == 1f) {
|
||||
float min = Float.POSITIVE_INFINITY;
|
||||
float max = Float.NEGATIVE_INFINITY;
|
||||
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
for (float v : floatVectorValues.vectorValue()) {
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS) {
|
||||
for (float v : floatVectorValues.vectorValue(iterator.index())) {
|
||||
min = Math.min(min, v);
|
||||
max = Math.max(max, v);
|
||||
}
|
||||
|
@ -289,8 +291,8 @@ public class ScalarQuantizer {
|
|||
if (totalVectorCount <= quantizationSampleSize) {
|
||||
int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
|
||||
int i = 0;
|
||||
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
float[] vectorValue = floatVectorValues.vectorValue();
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS) {
|
||||
float[] vectorValue = floatVectorValues.vectorValue(iterator.index());
|
||||
System.arraycopy(
|
||||
vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length);
|
||||
i++;
|
||||
|
@ -311,11 +313,11 @@ public class ScalarQuantizer {
|
|||
for (int i : vectorsToTake) {
|
||||
while (index <= i) {
|
||||
// We cannot use `advance(docId)` as MergedVectorValues does not support it
|
||||
floatVectorValues.nextDoc();
|
||||
iterator.nextDoc();
|
||||
index++;
|
||||
}
|
||||
assert floatVectorValues.docID() != NO_MORE_DOCS;
|
||||
float[] vectorValue = floatVectorValues.vectorValue();
|
||||
assert iterator.docID() != NO_MORE_DOCS;
|
||||
float[] vectorValue = floatVectorValues.vectorValue(iterator.index());
|
||||
System.arraycopy(
|
||||
vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length);
|
||||
idx++;
|
||||
|
@ -353,11 +355,16 @@ public class ScalarQuantizer {
|
|||
/ (floatVectorValues.dimension() + 1),
|
||||
1 - 1f / (floatVectorValues.dimension() + 1)
|
||||
};
|
||||
KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator();
|
||||
if (totalVectorCount <= sampleSize) {
|
||||
int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount);
|
||||
int i = 0;
|
||||
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i);
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS) {
|
||||
gatherSample(
|
||||
floatVectorValues.vectorValue(iterator.index()),
|
||||
quantileGatheringScratch,
|
||||
sampledDocs,
|
||||
i);
|
||||
i++;
|
||||
if (i == scratchSize) {
|
||||
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
|
||||
|
@ -374,11 +381,15 @@ public class ScalarQuantizer {
|
|||
for (int i : vectorsToTake) {
|
||||
while (index <= i) {
|
||||
// We cannot use `advance(docId)` as MergedVectorValues does not support it
|
||||
floatVectorValues.nextDoc();
|
||||
iterator.nextDoc();
|
||||
index++;
|
||||
}
|
||||
assert floatVectorValues.docID() != NO_MORE_DOCS;
|
||||
gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, idx);
|
||||
assert iterator.docID() != NO_MORE_DOCS;
|
||||
gatherSample(
|
||||
floatVectorValues.vectorValue(iterator.index()),
|
||||
quantileGatheringScratch,
|
||||
sampledDocs,
|
||||
idx);
|
||||
idx++;
|
||||
if (idx == SCRATCH_SIZE) {
|
||||
extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum);
|
||||
|
@ -437,12 +448,7 @@ public class ScalarQuantizer {
|
|||
}
|
||||
|
||||
private static void gatherSample(
|
||||
FloatVectorValues floatVectorValues,
|
||||
float[] quantileGatheringScratch,
|
||||
List<float[]> sampledDocs,
|
||||
int i)
|
||||
throws IOException {
|
||||
float[] vectorValue = floatVectorValues.vectorValue();
|
||||
float[] vectorValue, float[] quantileGatheringScratch, List<float[]> sampledDocs, int i) {
|
||||
float[] copy = new float[vectorValue.length];
|
||||
System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length);
|
||||
sampledDocs.add(copy);
|
||||
|
|
|
@ -19,11 +19,11 @@ package org.apache.lucene.internal.vectorization;
|
|||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import java.util.Optional;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
||||
|
@ -39,10 +39,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
|||
* returned.
|
||||
*/
|
||||
public static Optional<Lucene99MemorySegmentByteVectorScorer> create(
|
||||
VectorSimilarityFunction type,
|
||||
IndexInput input,
|
||||
RandomAccessVectorValues values,
|
||||
byte[] queryVector) {
|
||||
VectorSimilarityFunction type, IndexInput input, KnnVectorValues values, byte[] queryVector) {
|
||||
input = FilterIndexInput.unwrapOnlyTest(input);
|
||||
if (!(input instanceof MemorySegmentAccessInput msInput)) {
|
||||
return Optional.empty();
|
||||
|
@ -58,7 +55,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
|||
}
|
||||
|
||||
Lucene99MemorySegmentByteVectorScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) {
|
||||
MemorySegmentAccessInput input, KnnVectorValues values, byte[] queryVector) {
|
||||
super(values);
|
||||
this.input = input;
|
||||
this.vectorByteSize = values.getVectorByteLength();
|
||||
|
@ -92,7 +89,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
|||
}
|
||||
|
||||
static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
CosineScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
|
@ -105,8 +102,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
|||
}
|
||||
|
||||
static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
DotProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
DotProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
|
@ -120,7 +116,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
|||
}
|
||||
|
||||
static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
EuclideanScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
|
@ -133,8 +129,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer
|
|||
}
|
||||
|
||||
static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer {
|
||||
MaxInnerProductScorer(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) {
|
||||
MaxInnerProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) {
|
||||
super(input, values, query);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,11 +19,11 @@ package org.apache.lucene.internal.vectorization;
|
|||
import java.io.IOException;
|
||||
import java.lang.foreign.MemorySegment;
|
||||
import java.util.Optional;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.FilterIndexInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.MemorySegmentAccessInput;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
|
@ -33,7 +33,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
final int vectorByteSize;
|
||||
final int maxOrd;
|
||||
final MemorySegmentAccessInput input;
|
||||
final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds
|
||||
final KnnVectorValues values; // to support ordToDoc/getAcceptOrds
|
||||
byte[] scratch1, scratch2;
|
||||
|
||||
/**
|
||||
|
@ -41,7 +41,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
* optional is returned.
|
||||
*/
|
||||
static Optional<RandomVectorScorerSupplier> create(
|
||||
VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) {
|
||||
VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) {
|
||||
input = FilterIndexInput.unwrapOnlyTest(input);
|
||||
if (!(input instanceof MemorySegmentAccessInput msInput)) {
|
||||
return Optional.empty();
|
||||
|
@ -56,7 +56,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
}
|
||||
|
||||
Lucene99MemorySegmentByteVectorScorerSupplier(
|
||||
MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
MemorySegmentAccessInput input, KnnVectorValues values) {
|
||||
this.input = input;
|
||||
this.values = values;
|
||||
this.vectorByteSize = values.getVectorByteLength();
|
||||
|
@ -103,7 +103,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
|
||||
static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
CosineSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
|
@ -128,7 +128,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
|
||||
static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
DotProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
|
@ -155,7 +155,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
|
||||
static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
EuclideanSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
|
@ -181,7 +181,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier
|
|||
|
||||
static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier {
|
||||
|
||||
MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) {
|
||||
MaxInnerProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) {
|
||||
super(input, values);
|
||||
}
|
||||
|
||||
|
|
|
@ -19,11 +19,12 @@ package org.apache.lucene.internal.vectorization;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
|
||||
public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer {
|
||||
|
||||
|
@ -38,15 +39,15 @@ public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer
|
|||
|
||||
@Override
|
||||
public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues)
|
||||
throws IOException {
|
||||
VectorSimilarityFunction similarityType, KnnVectorValues vectorValues) throws IOException {
|
||||
// a quantized values here is a wrapping or delegation issue
|
||||
assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues);
|
||||
assert !(vectorValues instanceof QuantizedByteVectorValues);
|
||||
// currently only supports binary vectors
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
|
||||
if (vectorValues instanceof HasIndexSlice byteVectorValues
|
||||
&& byteVectorValues.getSlice() != null) {
|
||||
var scorer =
|
||||
Lucene99MemorySegmentByteVectorScorerSupplier.create(
|
||||
similarityType, vectorValues.getSlice(), vectorValues);
|
||||
similarityType, byteVectorValues.getSlice(), vectorValues);
|
||||
if (scorer.isPresent()) {
|
||||
return scorer.get();
|
||||
}
|
||||
|
@ -56,9 +57,7 @@ public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityType,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
float[] target)
|
||||
VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, float[] target)
|
||||
throws IOException {
|
||||
// currently only supports binary vectors, so always delegate
|
||||
return delegate.getRandomVectorScorer(similarityType, vectorValues, target);
|
||||
|
@ -66,17 +65,16 @@ public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer
|
|||
|
||||
@Override
|
||||
public RandomVectorScorer getRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityType,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
byte[] queryVector)
|
||||
VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, byte[] queryVector)
|
||||
throws IOException {
|
||||
checkDimensions(queryVector.length, vectorValues.dimension());
|
||||
// a quantized values here is a wrapping or delegation issue
|
||||
assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues);
|
||||
if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) {
|
||||
assert !(vectorValues instanceof QuantizedByteVectorValues);
|
||||
if (vectorValues instanceof HasIndexSlice byteVectorValues
|
||||
&& byteVectorValues.getSlice() != null) {
|
||||
var scorer =
|
||||
Lucene99MemorySegmentByteVectorScorer.create(
|
||||
similarityType, vectorValues.getSlice(), vectorValues, queryVector);
|
||||
similarityType, byteVectorValues.getSlice(), vectorValues, queryVector);
|
||||
if (scorer.isPresent()) {
|
||||
return scorer.get();
|
||||
}
|
||||
|
|
|
@ -35,6 +35,8 @@ import java.util.concurrent.atomic.AtomicInteger;
|
|||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
|
@ -42,7 +44,6 @@ import org.apache.lucene.store.IndexInput;
|
|||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.hamcrest.Matcher;
|
||||
import org.hamcrest.MatcherAssert;
|
||||
|
||||
|
@ -174,13 +175,13 @@ public class TestFlatVectorScorer extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
RandomAccessVectorValues byteVectorValues(
|
||||
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
|
||||
ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim)
|
||||
throws IOException {
|
||||
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
dims, size, in.slice("byteValues", 0, in.length()), dims, flatVectorsScorer, sim);
|
||||
}
|
||||
|
||||
RandomAccessVectorValues floatVectorValues(
|
||||
FloatVectorValues floatVectorValues(
|
||||
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
|
||||
return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||
dims,
|
||||
|
|
|
@ -17,7 +17,6 @@
|
|||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static java.lang.String.format;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.oneOf;
|
||||
|
||||
|
@ -312,14 +311,13 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
|||
assertNotNull(hnswReader.getQuantizationState("f"));
|
||||
QuantizedByteVectorValues quantizedByteVectorValues =
|
||||
hnswReader.getQuantizedVectorValues("f");
|
||||
int docId = -1;
|
||||
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
byte[] vector = quantizedByteVectorValues.vectorValue();
|
||||
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
|
||||
for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) {
|
||||
byte[] vector = quantizedByteVectorValues.vectorValue(ord);
|
||||
float offset = quantizedByteVectorValues.getScoreCorrectionConstant(ord);
|
||||
for (int i = 0; i < dim; i++) {
|
||||
assertEquals(vector[i], expectedVectors[docId][i]);
|
||||
assertEquals(vector[i], expectedVectors[ord][i]);
|
||||
}
|
||||
assertEquals(offset, expectedCorrections[docId], 0.00001f);
|
||||
assertEquals(offset, expectedCorrections[ord], 0.00001f);
|
||||
}
|
||||
} else {
|
||||
fail("reader is not Lucene99HnswVectorsReader");
|
||||
|
|
|
@ -46,7 +46,7 @@ import org.apache.lucene.store.IndexOutput;
|
|||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
|
||||
public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
|
||||
|
@ -100,8 +100,8 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
|
|||
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
|
||||
Lucene99ScalarQuantizedVectorScorer scorer =
|
||||
new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
|
||||
RandomAccessQuantizedByteVectorValues values =
|
||||
new RandomAccessQuantizedByteVectorValues() {
|
||||
QuantizedByteVectorValues values =
|
||||
new QuantizedByteVectorValues() {
|
||||
@Override
|
||||
public int dimension() {
|
||||
return 32;
|
||||
|
@ -128,7 +128,7 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessQuantizedByteVectorValues copy() throws IOException {
|
||||
public QuantizedByteVectorValues copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ import org.apache.lucene.index.DirectoryReader;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.NoMergePolicy;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -173,9 +174,10 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
|
|||
QuantizedByteVectorValues quantizedByteVectorValues =
|
||||
quantizedReader.getQuantizedVectorValues("f");
|
||||
int docId = -1;
|
||||
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
byte[] vector = quantizedByteVectorValues.vectorValue();
|
||||
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
|
||||
KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator();
|
||||
for (docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) {
|
||||
byte[] vector = quantizedByteVectorValues.vectorValue(iter.index());
|
||||
float offset = quantizedByteVectorValues.getScoreCorrectionConstant(iter.index());
|
||||
for (int i = 0; i < dim; i++) {
|
||||
assertEquals(vector[i], expectedVectors[docId][i]);
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@ package org.apache.lucene.document;
|
|||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.StringReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
|
@ -27,6 +28,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
|
@ -713,17 +715,21 @@ public class TestField extends LuceneTestCase {
|
|||
try (IndexReader r = DirectoryReader.open(w)) {
|
||||
ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary");
|
||||
assertEquals(1, binary.size());
|
||||
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||
assertNotNull(binary.vectorValue());
|
||||
assertArrayEquals(b, binary.vectorValue());
|
||||
assertEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = binary.iterator();
|
||||
assertNotEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
assertNotNull(binary.vectorValue(0));
|
||||
assertArrayEquals(b, binary.vectorValue(0));
|
||||
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
expectThrows(IOException.class, () -> binary.vectorValue(1));
|
||||
|
||||
FloatVectorValues floatValues = r.leaves().get(0).reader().getFloatVectorValues("float");
|
||||
assertEquals(1, floatValues.size());
|
||||
assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc());
|
||||
assertEquals(vector.length, floatValues.vectorValue().length);
|
||||
assertEquals(vector[0], floatValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, floatValues.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator1 = floatValues.iterator();
|
||||
assertNotEquals(NO_MORE_DOCS, iterator1.nextDoc());
|
||||
assertEquals(vector.length, floatValues.vectorValue(0).length);
|
||||
assertEquals(vector[0], floatValues.vectorValue(0)[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, iterator1.nextDoc());
|
||||
expectThrows(IOException.class, () -> floatValues.vectorValue(1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -459,8 +459,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
expectThrows(
|
||||
ExitingReaderException.class,
|
||||
() -> {
|
||||
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
|
||||
scanAndRetrieve(leaf, iter);
|
||||
KnnVectorValues values = leaf.getFloatVectorValues("vector");
|
||||
scanAndRetrieve(leaf, values);
|
||||
});
|
||||
|
||||
expectThrows(
|
||||
|
@ -473,8 +473,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
leaf.getLiveDocs(),
|
||||
Integer.MAX_VALUE));
|
||||
} else {
|
||||
DocIdSetIterator iter = leaf.getFloatVectorValues("vector");
|
||||
scanAndRetrieve(leaf, iter);
|
||||
KnnVectorValues values = leaf.getFloatVectorValues("vector");
|
||||
scanAndRetrieve(leaf, values);
|
||||
|
||||
leaf.searchNearestVectors(
|
||||
"vector",
|
||||
|
@ -534,8 +534,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
expectThrows(
|
||||
ExitingReaderException.class,
|
||||
() -> {
|
||||
DocIdSetIterator iter = leaf.getByteVectorValues("vector");
|
||||
scanAndRetrieve(leaf, iter);
|
||||
KnnVectorValues values = leaf.getByteVectorValues("vector");
|
||||
scanAndRetrieve(leaf, values);
|
||||
});
|
||||
|
||||
expectThrows(
|
||||
|
@ -549,8 +549,8 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
Integer.MAX_VALUE));
|
||||
|
||||
} else {
|
||||
DocIdSetIterator iter = leaf.getByteVectorValues("vector");
|
||||
scanAndRetrieve(leaf, iter);
|
||||
KnnVectorValues values = leaf.getByteVectorValues("vector");
|
||||
scanAndRetrieve(leaf, values);
|
||||
|
||||
leaf.searchNearestVectors(
|
||||
"vector",
|
||||
|
@ -564,20 +564,24 @@ public class TestExitableDirectoryReader extends LuceneTestCase {
|
|||
directory.close();
|
||||
}
|
||||
|
||||
private static void scanAndRetrieve(LeafReader leaf, DocIdSetIterator iter) throws IOException {
|
||||
private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) throws IOException {
|
||||
KnnVectorValues.DocIndexIterator iter = values.iterator();
|
||||
for (iter.nextDoc();
|
||||
iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) {
|
||||
final int nextDocId = iter.docID() + 1;
|
||||
int docId = iter.docID();
|
||||
if (docId >= leaf.maxDoc()) {
|
||||
break;
|
||||
}
|
||||
final int nextDocId = docId + 1;
|
||||
if (random().nextBoolean() && nextDocId < leaf.maxDoc()) {
|
||||
iter.advance(nextDocId);
|
||||
} else {
|
||||
iter.nextDoc();
|
||||
}
|
||||
|
||||
if (random().nextBoolean()
|
||||
&& iter.docID() != DocIdSetIterator.NO_MORE_DOCS
|
||||
&& iter instanceof FloatVectorValues) {
|
||||
((FloatVectorValues) iter).vectorValue();
|
||||
&& values instanceof FloatVectorValues) {
|
||||
((FloatVectorValues) values).vectorValue(iter.index());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -413,11 +413,13 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
// stored vector values are the same as original
|
||||
int nextDocWithVectors = 0;
|
||||
StoredFields storedFields = reader.storedFields();
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
for (int i = 0; i < reader.maxDoc(); i++) {
|
||||
nextDocWithVectors = vectorValues.advance(i);
|
||||
nextDocWithVectors = iterator.advance(i);
|
||||
while (i < nextDocWithVectors && i < reader.maxDoc()) {
|
||||
int id = Integer.parseInt(storedFields.document(i).get("id"));
|
||||
assertNull("document " + id + " has no vector, but was expected to", values[id]);
|
||||
assertNull(
|
||||
"document " + id + ", expected to have no vector, does have one", values[id]);
|
||||
++i;
|
||||
}
|
||||
if (nextDocWithVectors == NO_MORE_DOCS) {
|
||||
|
@ -425,7 +427,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
}
|
||||
int id = Integer.parseInt(storedFields.document(i).get("id"));
|
||||
// documents with KnnGraphValues have the expected vectors
|
||||
float[] scratch = vectorValues.vectorValue();
|
||||
float[] scratch = vectorValues.vectorValue(iterator.index());
|
||||
assertArrayEquals(
|
||||
"vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
|
||||
values[id],
|
||||
|
@ -435,9 +437,9 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
}
|
||||
// if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()
|
||||
if (nextDocWithVectors != NO_MORE_DOCS) {
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
} else {
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.docID());
|
||||
assertEquals(NO_MORE_DOCS, iterator.docID());
|
||||
}
|
||||
|
||||
// assert graph values:
|
||||
|
|
|
@ -242,6 +242,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
|
|||
NumericDocValues ids = leaf.getNumericDocValues("id");
|
||||
long prevValue = -1;
|
||||
boolean usingAltIds = false;
|
||||
KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator();
|
||||
for (int i = 0; i < actualNumDocs; i++) {
|
||||
int idNext = ids.nextDoc();
|
||||
if (idNext == DocIdSetIterator.NO_MORE_DOCS) {
|
||||
|
@ -262,7 +263,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
|
|||
assertTrue(sorted_numeric_dv.advanceExact(idNext));
|
||||
assertTrue(sorted_set_dv.advanceExact(idNext));
|
||||
assertTrue(binary_sorted_dv.advanceExact(idNext));
|
||||
assertEquals(idNext, vectorValues.advance(idNext));
|
||||
assertEquals(idNext, valuesIterator.advance(idNext));
|
||||
assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue());
|
||||
assertEquals(
|
||||
new BytesRef(ids.longValue() + ""),
|
||||
|
@ -274,7 +275,7 @@ public class TestSortingCodecReader extends LuceneTestCase {
|
|||
assertEquals(1, sorted_numeric_dv.docValueCount());
|
||||
assertEquals(ids.longValue(), sorted_numeric_dv.nextValue());
|
||||
|
||||
float[] vectorValue = vectorValues.vectorValue();
|
||||
float[] vectorValue = vectorValues.vectorValue(valuesIterator.index());
|
||||
assertEquals(1, vectorValue.length);
|
||||
assertEquals((float) ids.longValue(), vectorValue[0], 0.001f);
|
||||
|
||||
|
|
|
@ -39,6 +39,7 @@ import java.util.stream.IntStream;
|
|||
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
|
@ -47,7 +48,6 @@ import org.apache.lucene.store.IndexOutput;
|
|||
import org.apache.lucene.store.MMapDirectory;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.NamedThreadFactory;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.junit.BeforeClass;
|
||||
|
||||
|
@ -329,8 +329,8 @@ public class TestVectorScorer extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
RandomAccessVectorValues vectorValues(
|
||||
int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
|
||||
KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim)
|
||||
throws IOException {
|
||||
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim);
|
||||
}
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.apache.lucene.index.FilterLeafReader;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.QueryTimeout;
|
||||
|
@ -740,7 +741,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
LeafReader leafReader = getOnlyLeafReader(reader);
|
||||
FieldInfo fi = leafReader.getFieldInfos().fieldInfo("field");
|
||||
assertNotNull(fi);
|
||||
DocIdSetIterator vectorValues;
|
||||
KnnVectorValues vectorValues;
|
||||
switch (fi.getVectorEncoding()) {
|
||||
case BYTE:
|
||||
vectorValues = leafReader.getByteVectorValues("field");
|
||||
|
@ -752,7 +753,7 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
|
|||
throw new AssertionError();
|
||||
}
|
||||
assertNotNull(vectorValues);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -113,7 +113,7 @@ public class TestTimeLimitingBulkScorer extends LuceneTestCase {
|
|||
private static QueryTimeout countingQueryTimeout(int timeallowed) {
|
||||
|
||||
return new QueryTimeout() {
|
||||
static int counter = 0;
|
||||
int counter = 0;
|
||||
|
||||
@Override
|
||||
public boolean shouldExit() {
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final T[] denseValues;
|
||||
protected final T[] values;
|
||||
protected final int numVectors;
|
||||
protected final BytesRef binaryValue;
|
||||
|
||||
protected int pos = -1;
|
||||
|
||||
AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) {
|
||||
this.dimension = dimension;
|
||||
this.values = values;
|
||||
this.denseValues = denseValues;
|
||||
// used by tests that build a graph from bytes rather than floats
|
||||
binaryValue = new BytesRef(dimension);
|
||||
binaryValue.length = dimension;
|
||||
this.numVectors = numVectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numVectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
public T vectorValue(int targetOrd) {
|
||||
return denseValues[targetOrd];
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract AbstractMockVectorValues<T> copy();
|
||||
|
||||
public abstract T vectorValue() throws IOException;
|
||||
|
||||
private boolean seek(int target) {
|
||||
if (target >= 0 && target < values.length && values[target] != null) {
|
||||
pos = target;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public int docID() {
|
||||
return pos;
|
||||
}
|
||||
|
||||
public int nextDoc() {
|
||||
return advance(pos + 1);
|
||||
}
|
||||
|
||||
public int advance(int target) {
|
||||
while (++pos < values.length) {
|
||||
if (seek(pos)) {
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
}
|
|
@ -56,6 +56,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.StoredFields;
|
||||
|
@ -97,33 +98,28 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
abstract T randomVector(int dim);
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(int size, int dimension);
|
||||
abstract KnnVectorValues vectorValues(int size, int dimension);
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(float[][] values);
|
||||
abstract KnnVectorValues vectorValues(float[][] values);
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
|
||||
throws IOException;
|
||||
abstract KnnVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException;
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(
|
||||
int size,
|
||||
int dimension,
|
||||
AbstractMockVectorValues<T> pregeneratedVectorValues,
|
||||
int pregeneratedOffset);
|
||||
abstract KnnVectorValues vectorValues(
|
||||
int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset);
|
||||
|
||||
abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
|
||||
|
||||
abstract RandomAccessVectorValues circularVectorValues(int nDoc);
|
||||
abstract KnnVectorValues circularVectorValues(int nDoc);
|
||||
|
||||
abstract T getTargetVector();
|
||||
|
||||
protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors)
|
||||
protected RandomVectorScorerSupplier buildScorerSupplier(KnnVectorValues vectors)
|
||||
throws IOException {
|
||||
return flatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectors);
|
||||
}
|
||||
|
||||
protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query)
|
||||
throws IOException {
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
protected RandomVectorScorer buildScorer(KnnVectorValues vectors, T query) throws IOException {
|
||||
KnnVectorValues vectorsCopy = vectors.copy();
|
||||
return switch (getVectorEncoding()) {
|
||||
case BYTE ->
|
||||
flatVectorScorer.getRandomVectorScorer(similarityFunction, vectorsCopy, (byte[]) query);
|
||||
|
@ -134,6 +130,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
|
||||
// Tests writing segments of various sizes and merging to ensure there are no errors
|
||||
// in the HNSW graph merging logic.
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testRandomReadWriteAndMerge() throws IOException {
|
||||
int dim = random().nextInt(100) + 1;
|
||||
int[] segmentSizes =
|
||||
|
@ -148,7 +145,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int M = random().nextInt(4) + 2;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
long seed = random().nextLong();
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(numVectors, dim);
|
||||
KnnVectorValues vectors = vectorValues(numVectors, dim);
|
||||
HnswGraphBuilder.randSeed = seed;
|
||||
|
||||
try (Directory dir = newDirectory()) {
|
||||
|
@ -173,7 +170,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
for (int i = 0; i < segmentSizes.length; i++) {
|
||||
int size = segmentSizes[i];
|
||||
while (vectors.nextDoc() < size) {
|
||||
for (int ord = 0; ord < size; ord++) {
|
||||
if (isSparse[i] && random().nextBoolean()) {
|
||||
int d = random().nextInt(10) + 1;
|
||||
for (int j = 0; j < d; j++) {
|
||||
|
@ -182,8 +179,24 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
Document doc = new Document();
|
||||
doc.add(knnVectorField("field", vectors.vectorValue(), similarityFunction));
|
||||
doc.add(new StringField("id", Integer.toString(vectors.docID()), Field.Store.NO));
|
||||
switch (vectors.getEncoding()) {
|
||||
case BYTE -> {
|
||||
doc.add(
|
||||
knnVectorField(
|
||||
"field",
|
||||
(T) ((ByteVectorValues) vectors).vectorValue(ord),
|
||||
similarityFunction));
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
doc.add(
|
||||
knnVectorField(
|
||||
"field",
|
||||
(T) ((FloatVectorValues) vectors).vectorValue(ord),
|
||||
similarityFunction));
|
||||
}
|
||||
}
|
||||
;
|
||||
doc.add(new StringField("id", Integer.toString(vectors.ordToDoc(ord)), Field.Store.NO));
|
||||
iw.addDocument(doc);
|
||||
}
|
||||
iw.commit();
|
||||
|
@ -199,13 +212,26 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
try (IndexReader reader = DirectoryReader.open(dir)) {
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
|
||||
KnnVectorValues values = vectorValues(ctx.reader(), "field");
|
||||
assertEquals(dim, values.dimension());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private T vectorValue(KnnVectorValues vectors, int ord) throws IOException {
|
||||
switch (vectors.getEncoding()) {
|
||||
case BYTE -> {
|
||||
return (T) ((ByteVectorValues) vectors).vectorValue(ord);
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
return (T) ((FloatVectorValues) vectors).vectorValue(ord);
|
||||
}
|
||||
}
|
||||
throw new AssertionError("unknown encoding " + vectors.getEncoding());
|
||||
}
|
||||
|
||||
// test writing out and reading in a graph gives the expected graph
|
||||
public void testReadWrite() throws IOException {
|
||||
int dim = random().nextInt(100) + 1;
|
||||
|
@ -213,8 +239,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int M = random().nextInt(4) + 2;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
long seed = random().nextLong();
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
|
||||
AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
|
||||
KnnVectorValues vectors = vectorValues(nDoc, dim);
|
||||
KnnVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed);
|
||||
HnswGraph hnsw = builder.build(vectors.size());
|
||||
|
@ -242,15 +268,16 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
});
|
||||
try (IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||
while (v2.nextDoc() != NO_MORE_DOCS) {
|
||||
while (indexedDoc < v2.docID()) {
|
||||
KnnVectorValues.DocIndexIterator it2 = v2.iterator();
|
||||
while (it2.nextDoc() != NO_MORE_DOCS) {
|
||||
while (indexedDoc < it2.docID()) {
|
||||
// increment docId in the index by adding empty documents
|
||||
iw.addDocument(new Document());
|
||||
indexedDoc++;
|
||||
}
|
||||
Document doc = new Document();
|
||||
doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction));
|
||||
doc.add(new StoredField("id", v2.docID()));
|
||||
doc.add(knnVectorField("field", vectorValue(v2, it2.index()), similarityFunction));
|
||||
doc.add(new StoredField("id", it2.docID()));
|
||||
iw.addDocument(doc);
|
||||
nVec++;
|
||||
indexedDoc++;
|
||||
|
@ -258,7 +285,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
try (IndexReader reader = DirectoryReader.open(dir)) {
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
|
||||
KnnVectorValues values = vectorValues(ctx.reader(), "field");
|
||||
assertEquals(dim, values.dimension());
|
||||
assertEquals(nVec, values.size());
|
||||
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
||||
|
@ -280,7 +307,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
||||
int dim = random().nextInt(10) + 3;
|
||||
int nDoc = random().nextInt(200) + 100;
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
|
||||
KnnVectorValues vectors = vectorValues(nDoc, dim);
|
||||
|
||||
int M = random().nextInt(10) + 5;
|
||||
int beamWidth = random().nextInt(10) + 10;
|
||||
|
@ -323,15 +350,15 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int indexedDoc = 0;
|
||||
try (IndexWriter iw = new IndexWriter(dir, iwc);
|
||||
IndexWriter iw2 = new IndexWriter(dir2, iwc2)) {
|
||||
while (vectors.nextDoc() != NO_MORE_DOCS) {
|
||||
while (indexedDoc < vectors.docID()) {
|
||||
for (int ord = 0; ord < vectors.size(); ord++) {
|
||||
while (indexedDoc < vectors.ordToDoc(ord)) {
|
||||
// increment docId in the index by adding empty documents
|
||||
iw.addDocument(new Document());
|
||||
indexedDoc++;
|
||||
}
|
||||
Document doc = new Document();
|
||||
doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction));
|
||||
doc.add(new StoredField("id", vectors.docID()));
|
||||
doc.add(knnVectorField("vector", vectorValue(vectors, ord), similarityFunction));
|
||||
doc.add(new StoredField("id", vectors.ordToDoc(ord)));
|
||||
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
|
||||
iw.addDocument(doc);
|
||||
iw2.addDocument(doc);
|
||||
|
@ -461,7 +488,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
public void testAknnDiverse() throws IOException {
|
||||
int nDoc = 100;
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
|
||||
KnnVectorValues vectors = circularVectorValues(nDoc);
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.size());
|
||||
|
@ -493,7 +520,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testSearchWithAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
|
||||
KnnVectorValues vectors = circularVectorValues(nDoc);
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
|
||||
|
@ -518,7 +545,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
@SuppressWarnings("unchecked")
|
||||
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
|
||||
KnnVectorValues vectors = circularVectorValues(nDoc);
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
|
||||
|
@ -552,13 +579,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int dim = atLeast(10);
|
||||
long seed = random().nextLong();
|
||||
|
||||
AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
|
||||
KnnVectorValues initializerVectors = vectorValues(initializerSize, dim);
|
||||
RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
|
||||
HnswGraphBuilder initializerBuilder =
|
||||
HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
|
||||
|
||||
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
|
||||
AbstractMockVectorValues<T> finalVectorValues =
|
||||
KnnVectorValues finalVectorValues =
|
||||
vectorValues(totalSize, dim, initializerVectors, docIdOffset);
|
||||
int[] initializerOrdMap =
|
||||
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
|
||||
|
@ -598,13 +625,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int dim = atLeast(10);
|
||||
long seed = random().nextLong();
|
||||
|
||||
AbstractMockVectorValues<T> initializerVectors = vectorValues(initializerSize, dim);
|
||||
KnnVectorValues initializerVectors = vectorValues(initializerSize, dim);
|
||||
RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors);
|
||||
HnswGraphBuilder initializerBuilder =
|
||||
HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed);
|
||||
|
||||
OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size());
|
||||
AbstractMockVectorValues<T> finalVectorValues =
|
||||
KnnVectorValues finalVectorValues =
|
||||
vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset);
|
||||
int[] initializerOrdMap =
|
||||
createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset);
|
||||
|
@ -688,19 +715,17 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
private int[] createOffsetOrdinalMap(
|
||||
int docIdSize, AbstractMockVectorValues<T> totalVectorValues, int docIdOffset) {
|
||||
int docIdSize, KnnVectorValues totalVectorValues, int docIdOffset) throws IOException {
|
||||
// Compute the offset for the ordinal map to be the number of non-null vectors in the total
|
||||
// vector values
|
||||
// before the docIdOffset
|
||||
// vector values before the docIdOffset
|
||||
int ordinalOffset = 0;
|
||||
while (totalVectorValues.nextDoc() < docIdOffset) {
|
||||
KnnVectorValues.DocIndexIterator it = totalVectorValues.iterator();
|
||||
while (it.nextDoc() < docIdOffset) {
|
||||
ordinalOffset++;
|
||||
}
|
||||
int[] offsetOrdinalMap = new int[docIdSize];
|
||||
|
||||
for (int curr = 0;
|
||||
totalVectorValues.docID() < docIdOffset + docIdSize;
|
||||
totalVectorValues.nextDoc()) {
|
||||
for (int curr = 0; it.docID() < docIdOffset + docIdSize; it.nextDoc()) {
|
||||
offsetOrdinalMap[curr] = ordinalOffset + curr++;
|
||||
}
|
||||
|
||||
|
@ -711,7 +736,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
public void testVisitedLimit() throws IOException {
|
||||
int nDoc = 500;
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
RandomAccessVectorValues vectors = circularVectorValues(nDoc);
|
||||
KnnVectorValues vectors = circularVectorValues(nDoc);
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.size());
|
||||
|
@ -746,7 +771,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
int M = randomIntBetween(4, 96);
|
||||
|
||||
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
RandomAccessVectorValues vectors = vectorValues(size, dim);
|
||||
KnnVectorValues vectors = vectorValues(size, dim);
|
||||
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder =
|
||||
|
@ -771,7 +796,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
unitVector2d(0.77),
|
||||
unitVector2d(0.6)
|
||||
};
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||
KnnVectorValues vectors = vectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 2, 10, random().nextInt());
|
||||
|
@ -825,7 +850,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
{10, 0, 0},
|
||||
{0, 4, 0}
|
||||
};
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||
KnnVectorValues vectors = vectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt());
|
||||
|
@ -855,7 +880,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
{0, 0, 20},
|
||||
{0, 9, 0}
|
||||
};
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||
KnnVectorValues vectors = vectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt());
|
||||
|
@ -891,7 +916,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
public void testRandom() throws IOException {
|
||||
int size = atLeast(100);
|
||||
int dim = atLeast(10);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
|
||||
KnnVectorValues vectors = vectorValues(size, dim);
|
||||
int topK = 5;
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong());
|
||||
|
@ -908,15 +933,13 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
TopDocs topDocs = actual.topDocs();
|
||||
NeighborQueue expected = new NeighborQueue(topK, false);
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||
if (vectorValue(vectors, j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||
if (getVectorEncoding() == VectorEncoding.BYTE) {
|
||||
assert query instanceof byte[];
|
||||
expected.add(
|
||||
j, similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(j)));
|
||||
j, similarityFunction.compare((byte[]) query, (byte[]) vectorValue(vectors, j)));
|
||||
} else {
|
||||
assert query instanceof float[];
|
||||
expected.add(
|
||||
j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j)));
|
||||
j, similarityFunction.compare((float[]) query, (float[]) vectorValue(vectors, j)));
|
||||
}
|
||||
if (expected.size() > topK) {
|
||||
expected.pop();
|
||||
|
@ -940,7 +963,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
throws IOException, ExecutionException, InterruptedException, TimeoutException {
|
||||
int size = atLeast(100);
|
||||
int dim = atLeast(10);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
|
||||
KnnVectorValues vectors = vectorValues(size, dim);
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.size());
|
||||
|
@ -1004,7 +1027,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
public void testConcurrentMergeBuilder() throws IOException {
|
||||
int size = atLeast(1000);
|
||||
int dim = atLeast(10);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
|
||||
KnnVectorValues vectors = vectorValues(size, dim);
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge"));
|
||||
TaskExecutor taskExecutor = new TaskExecutor(exec);
|
||||
|
@ -1033,7 +1056,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
// Search for a large number of results
|
||||
int topK = size - 1;
|
||||
|
||||
AbstractMockVectorValues<T> docVectors = vectorValues(size, dim);
|
||||
KnnVectorValues docVectors = vectorValues(size, dim);
|
||||
HnswGraph graph =
|
||||
HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong())
|
||||
.build(size);
|
||||
|
@ -1047,8 +1070,8 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
};
|
||||
|
||||
AbstractMockVectorValues<T> queryVectors = vectorValues(1, dim);
|
||||
RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0));
|
||||
KnnVectorValues queryVectors = vectorValues(1, dim);
|
||||
RandomVectorScorer queryScorer = buildScorer(docVectors, vectorValue(queryVectors, 0));
|
||||
|
||||
KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE);
|
||||
HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null);
|
||||
|
@ -1076,8 +1099,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
static class CircularFloatVectorValues extends FloatVectorValues
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
static class CircularFloatVectorValues extends FloatVectorValues {
|
||||
private final int size;
|
||||
private final float[] value;
|
||||
|
||||
|
@ -1103,22 +1125,18 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
if (target >= 0 && target < size) {
|
||||
doc = target;
|
||||
|
@ -1140,8 +1158,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
}
|
||||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
static class CircularByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues.Bytes {
|
||||
static class CircularByteVectorValues extends ByteVectorValues {
|
||||
private final int size;
|
||||
private final float[] value;
|
||||
private final byte[] bValue;
|
||||
|
@ -1169,22 +1186,18 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
if (target >= 0 && target < size) {
|
||||
doc = target;
|
||||
|
@ -1227,27 +1240,25 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return neighbors;
|
||||
}
|
||||
|
||||
void assertVectorsEqual(AbstractMockVectorValues<T> u, AbstractMockVectorValues<T> v)
|
||||
throws IOException {
|
||||
void assertVectorsEqual(KnnVectorValues u, KnnVectorValues v) throws IOException {
|
||||
int uDoc, vDoc;
|
||||
while (true) {
|
||||
uDoc = u.nextDoc();
|
||||
vDoc = v.nextDoc();
|
||||
assertEquals(u.size(), v.size());
|
||||
for (int ord = 0; ord < u.size(); ord++) {
|
||||
uDoc = u.ordToDoc(ord);
|
||||
vDoc = v.ordToDoc(ord);
|
||||
assertEquals(uDoc, vDoc);
|
||||
if (uDoc == NO_MORE_DOCS) {
|
||||
break;
|
||||
}
|
||||
assertNotEquals(NO_MORE_DOCS, uDoc);
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE ->
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
(byte[]) u.vectorValue(),
|
||||
(byte[]) v.vectorValue());
|
||||
(byte[]) vectorValue(u, ord),
|
||||
(byte[]) vectorValue(v, ord));
|
||||
case FLOAT32 ->
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
(float[]) u.vectorValue(),
|
||||
(float[]) v.vectorValue(),
|
||||
(float[]) vectorValue(u, ord),
|
||||
(float[]) vectorValue(v, ord),
|
||||
1e-4f);
|
||||
default ->
|
||||
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
|
||||
|
|
|
@ -17,11 +17,17 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
|
||||
implements RandomAccessVectorValues.Bytes {
|
||||
class MockByteVectorValues extends ByteVectorValues {
|
||||
private final int dimension;
|
||||
private final byte[][] denseValues;
|
||||
protected final byte[][] values;
|
||||
private final int numVectors;
|
||||
private final BytesRef binaryValue;
|
||||
private final byte[] scratch;
|
||||
|
||||
static MockByteVectorValues fromValues(byte[][] values) {
|
||||
|
@ -43,10 +49,26 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
|
|||
}
|
||||
|
||||
MockByteVectorValues(byte[][] values, int dimension, byte[][] denseValues, int numVectors) {
|
||||
super(values, dimension, denseValues, numVectors);
|
||||
this.dimension = dimension;
|
||||
this.values = values;
|
||||
this.denseValues = denseValues;
|
||||
this.numVectors = numVectors;
|
||||
// used by tests that build a graph from bytes rather than floats
|
||||
binaryValue = new BytesRef(dimension);
|
||||
binaryValue.length = dimension;
|
||||
scratch = new byte[dimension];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return values.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockByteVectorValues copy() {
|
||||
return new MockByteVectorValues(
|
||||
|
@ -55,20 +77,20 @@ class MockByteVectorValues extends AbstractMockVectorValues<byte[]>
|
|||
|
||||
@Override
|
||||
public byte[] vectorValue(int ord) {
|
||||
return values[ord];
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
if (LuceneTestCase.random().nextBoolean()) {
|
||||
return values[pos];
|
||||
return values[ord];
|
||||
} else {
|
||||
// Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
|
||||
// This should help us catch cases of aliasing where the same ByteVectorValues source is used
|
||||
// twice in a
|
||||
// single computation.
|
||||
System.arraycopy(values[pos], 0, scratch, 0, dimension);
|
||||
System.arraycopy(values[ord], 0, scratch, 0, dimension);
|
||||
return scratch;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,11 +17,15 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
||||
class MockVectorValues extends AbstractMockVectorValues<float[]>
|
||||
implements RandomAccessVectorValues.Floats {
|
||||
class MockVectorValues extends FloatVectorValues {
|
||||
private final int dimension;
|
||||
private final float[][] denseValues;
|
||||
protected final float[][] values;
|
||||
private final int numVectors;
|
||||
private final float[] scratch;
|
||||
|
||||
static MockVectorValues fromValues(float[][] values) {
|
||||
|
@ -43,10 +47,23 @@ class MockVectorValues extends AbstractMockVectorValues<float[]>
|
|||
}
|
||||
|
||||
MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) {
|
||||
super(values, dimension, denseValues, numVectors);
|
||||
this.dimension = dimension;
|
||||
this.values = values;
|
||||
this.denseValues = denseValues;
|
||||
this.numVectors = numVectors;
|
||||
this.scratch = new float[dimension];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return values.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockVectorValues copy() {
|
||||
return new MockVectorValues(
|
||||
|
@ -54,20 +71,20 @@ class MockVectorValues extends AbstractMockVectorValues<float[]>
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
public float[] vectorValue(int ord) {
|
||||
if (LuceneTestCase.random().nextBoolean()) {
|
||||
return values[pos];
|
||||
return values[ord];
|
||||
} else {
|
||||
// Sometimes use the same scratch array repeatedly, mimicing what the codec will do.
|
||||
// This should help us catch cases of aliasing where the same vector values source is used
|
||||
// twice in a single computation.
|
||||
System.arraycopy(values[pos], 0, scratch, 0, dimension);
|
||||
System.arraycopy(values[ord], 0, scratch, 0, dimension);
|
||||
return scratch;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return denseValues[targetOrd];
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,13 +17,12 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -56,7 +55,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
|
|||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<byte[]> vectorValues(int size, int dimension) {
|
||||
MockByteVectorValues vectorValues(int size, int dimension) {
|
||||
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
|
||||
}
|
||||
|
||||
|
@ -65,7 +64,7 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
|
|||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<byte[]> vectorValues(float[][] values) {
|
||||
MockByteVectorValues vectorValues(float[][] values) {
|
||||
byte[][] bValues = new byte[values.length][];
|
||||
// The case when all floats fit within a byte already.
|
||||
boolean scaleSimple = fitsInByte(values[0][0]);
|
||||
|
@ -86,42 +85,35 @@ public class TestHnswByteVectorGraph extends HnswGraphTestCase<byte[]> {
|
|||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<byte[]> vectorValues(
|
||||
int size,
|
||||
int dimension,
|
||||
AbstractMockVectorValues<byte[]> pregeneratedVectorValues,
|
||||
int pregeneratedOffset) {
|
||||
MockByteVectorValues vectorValues(
|
||||
int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) {
|
||||
|
||||
MockByteVectorValues pvv = (MockByteVectorValues) pregeneratedVectorValues;
|
||||
byte[][] vectors = new byte[size][];
|
||||
byte[][] randomVectors =
|
||||
createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random());
|
||||
byte[][] randomVectors = createRandomByteVectors(size - pvv.values.length, dimension, random());
|
||||
|
||||
for (int i = 0; i < pregeneratedOffset; i++) {
|
||||
vectors[i] = randomVectors[i];
|
||||
}
|
||||
|
||||
int currentDoc;
|
||||
while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
|
||||
for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) {
|
||||
vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd];
|
||||
}
|
||||
|
||||
for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
|
||||
i < vectors.length;
|
||||
i++) {
|
||||
vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
|
||||
for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) {
|
||||
vectors[i] = randomVectors[i - pvv.values.length];
|
||||
}
|
||||
|
||||
return MockByteVectorValues.fromValues(vectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<byte[]> vectorValues(LeafReader reader, String fieldName)
|
||||
throws IOException {
|
||||
MockByteVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException {
|
||||
ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
|
||||
byte[][] vectors = new byte[reader.maxDoc()][];
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
vectors[vectorValues.docID()] =
|
||||
ArrayUtil.copyOfSubArray(
|
||||
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
|
||||
for (int i = 0; i < vectorValues.size(); i++) {
|
||||
vectors[vectorValues.ordToDoc(i)] =
|
||||
ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension());
|
||||
}
|
||||
return MockByteVectorValues.fromValues(vectors);
|
||||
}
|
||||
|
|
|
@ -17,13 +17,12 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -60,52 +59,44 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
|
|||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
|
||||
MockVectorValues vectorValues(int size, int dimension) {
|
||||
return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random()));
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(float[][] values) {
|
||||
MockVectorValues vectorValues(float[][] values) {
|
||||
return MockVectorValues.fromValues(values);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldName)
|
||||
throws IOException {
|
||||
MockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException {
|
||||
FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName);
|
||||
float[][] vectors = new float[reader.maxDoc()][];
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
vectors[vectorValues.docID()] =
|
||||
ArrayUtil.copyOfSubArray(
|
||||
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
|
||||
for (int i = 0; i < vectorValues.size(); i++) {
|
||||
vectors[vectorValues.ordToDoc(i)] =
|
||||
ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension());
|
||||
}
|
||||
return MockVectorValues.fromValues(vectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(
|
||||
int size,
|
||||
int dimension,
|
||||
AbstractMockVectorValues<float[]> pregeneratedVectorValues,
|
||||
int pregeneratedOffset) {
|
||||
MockVectorValues vectorValues(
|
||||
int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) {
|
||||
MockVectorValues pvv = (MockVectorValues) pregeneratedVectorValues;
|
||||
float[][] vectors = new float[size][];
|
||||
float[][] randomVectors =
|
||||
createRandomFloatVectors(
|
||||
size - pregeneratedVectorValues.values.length, dimension, random());
|
||||
createRandomFloatVectors(size - pvv.values.length, dimension, random());
|
||||
|
||||
for (int i = 0; i < pregeneratedOffset; i++) {
|
||||
vectors[i] = randomVectors[i];
|
||||
}
|
||||
|
||||
int currentDoc;
|
||||
while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc];
|
||||
for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) {
|
||||
vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd];
|
||||
}
|
||||
|
||||
for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length;
|
||||
i < vectors.length;
|
||||
i++) {
|
||||
vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length];
|
||||
for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) {
|
||||
vectors[i] = randomVectors[i - pvv.values.length];
|
||||
}
|
||||
|
||||
return MockVectorValues.fromValues(vectors);
|
||||
|
@ -129,7 +120,7 @@ public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
|
|||
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
||||
int nDoc = 1000;
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
RandomAccessVectorValues.Floats vectors = circularVectorValues(nDoc);
|
||||
FloatVectorValues vectors = circularVectorValues(nDoc);
|
||||
RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors);
|
||||
HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.size());
|
||||
|
|
|
@ -138,12 +138,6 @@ public class TestHnswUtil extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
MockGraph graph = new MockGraph(nodes);
|
||||
/**/
|
||||
if (i == 2) {
|
||||
System.out.println("iter " + i);
|
||||
System.out.print(graph.toString());
|
||||
}
|
||||
/**/
|
||||
assertEquals(isRooted(nodes), HnswUtil.isRooted(graph));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,8 +59,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
|||
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
ScalarQuantizer.fromVectors(
|
||||
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
|
||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
|
||||
|
@ -92,8 +91,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
|||
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
||||
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
ScalarQuantizer.fromVectors(
|
||||
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
|
||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectorsNormalized(
|
||||
|
@ -129,8 +127,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
|||
float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
ScalarQuantizer.fromVectors(
|
||||
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
|
||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
|
||||
|
@ -162,8 +159,7 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
|||
float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
ScalarQuantizer.fromVectors(
|
||||
floatVectorValues, confidenceInterval, floats.length, (byte) 7);
|
||||
ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectors(
|
||||
|
@ -242,11 +238,8 @@ public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
|||
float[][] floats, Set<Integer> deletedVectors) {
|
||||
return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats, deletedVectors) {
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
if (curDoc == -1 || curDoc >= floats.length) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
float[] v = ArrayUtil.copyArray(floats[curDoc]);
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
float[] v = ArrayUtil.copyArray(floats[ordToDoc[ord]]);
|
||||
VectorUtil.l2normalize(v);
|
||||
return v;
|
||||
}
|
||||
|
|
|
@ -272,14 +272,27 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
|||
static class TestSimpleFloatVectorValues extends FloatVectorValues {
|
||||
protected final float[][] floats;
|
||||
protected final Set<Integer> deletedVectors;
|
||||
protected final int[] ordToDoc;
|
||||
protected final int numLiveVectors;
|
||||
protected int curDoc = -1;
|
||||
|
||||
TestSimpleFloatVectorValues(float[][] values, Set<Integer> deletedVectors) {
|
||||
this.floats = values;
|
||||
this.deletedVectors = deletedVectors;
|
||||
this.numLiveVectors =
|
||||
numLiveVectors =
|
||||
deletedVectors == null ? values.length : values.length - deletedVectors.size();
|
||||
ordToDoc = new int[numLiveVectors];
|
||||
if (deletedVectors == null) {
|
||||
for (int i = 0; i < numLiveVectors; i++) {
|
||||
ordToDoc[i] = i;
|
||||
}
|
||||
} else {
|
||||
int ord = 0;
|
||||
for (int doc = 0; doc < values.length; doc++) {
|
||||
if (!deletedVectors.contains(doc)) {
|
||||
ordToDoc[ord++] = doc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -293,40 +306,64 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
if (curDoc == -1 || curDoc >= floats.length) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
return floats[curDoc];
|
||||
public float[] vectorValue(int ord) throws IOException {
|
||||
return floats[ordToDoc(ord)];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curDoc >= floats.length) {
|
||||
return NO_MORE_DOCS;
|
||||
public int ordToDoc(int ord) {
|
||||
return ordToDoc[ord];
|
||||
}
|
||||
return curDoc;
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return new DocIndexIterator() {
|
||||
|
||||
int ord = -1;
|
||||
int doc = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
while (++curDoc < floats.length) {
|
||||
if (deletedVectors == null || !deletedVectors.contains(curDoc)) {
|
||||
return curDoc;
|
||||
while (doc < floats.length - 1) {
|
||||
++doc;
|
||||
if (deletedVectors == null || !deletedVectors.contains(doc)) {
|
||||
++ord;
|
||||
return doc;
|
||||
}
|
||||
}
|
||||
return docID();
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int index() {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return floats.length - deletedVectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
curDoc = target - 1;
|
||||
return nextDoc();
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TestSimpleFloatVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2285,7 +2285,6 @@ public class MemoryIndex {
|
|||
|
||||
private static final class MemoryFloatVectorValues extends FloatVectorValues {
|
||||
private final Info info;
|
||||
private int currentDoc = -1;
|
||||
|
||||
MemoryFloatVectorValues(Info info) {
|
||||
this.info = info;
|
||||
|
@ -2302,14 +2301,19 @@ public class MemoryIndex {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() {
|
||||
if (currentDoc == 0) {
|
||||
public float[] vectorValue(int ord) {
|
||||
if (ord == 0) {
|
||||
return info.floatVectorValues[0];
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) {
|
||||
if (query.length != info.fieldInfo.getVectorDimension()) {
|
||||
|
@ -2320,50 +2324,31 @@ public class MemoryIndex {
|
|||
+ info.fieldInfo.getVectorDimension());
|
||||
}
|
||||
MemoryFloatVectorValues vectorValues = new MemoryFloatVectorValues(info);
|
||||
DocIndexIterator iterator = vectorValues.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
assert iterator.docID() == 0;
|
||||
return info.fieldInfo
|
||||
.getVectorSimilarityFunction()
|
||||
.compare(vectorValues.vectorValue(), query);
|
||||
.compare(vectorValues.vectorValue(0), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return vectorValues;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return currentDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
int doc = ++currentDoc;
|
||||
if (doc == 0) {
|
||||
return doc;
|
||||
} else {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
if (target == 0) {
|
||||
currentDoc = target;
|
||||
return target;
|
||||
} else {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
public MemoryFloatVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
private static final class MemoryByteVectorValues extends ByteVectorValues {
|
||||
private final Info info;
|
||||
private int currentDoc = -1;
|
||||
|
||||
MemoryByteVectorValues(Info info) {
|
||||
this.info = info;
|
||||
|
@ -2380,14 +2365,19 @@ public class MemoryIndex {
|
|||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() {
|
||||
if (currentDoc == 0) {
|
||||
public byte[] vectorValue(int ord) {
|
||||
if (ord == 0) {
|
||||
return info.byteVectorValues[0];
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIndexIterator iterator() {
|
||||
return createDenseIterator();
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(byte[] query) {
|
||||
if (query.length != info.fieldInfo.getVectorDimension()) {
|
||||
|
@ -2398,44 +2388,26 @@ public class MemoryIndex {
|
|||
+ info.fieldInfo.getVectorDimension());
|
||||
}
|
||||
MemoryByteVectorValues vectorValues = new MemoryByteVectorValues(info);
|
||||
DocIndexIterator iterator = vectorValues.iterator();
|
||||
return new VectorScorer() {
|
||||
@Override
|
||||
public float score() {
|
||||
assert iterator.docID() == 0;
|
||||
return info.fieldInfo
|
||||
.getVectorSimilarityFunction()
|
||||
.compare(vectorValues.vectorValue(), query);
|
||||
.compare(vectorValues.vectorValue(0), query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return vectorValues;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return currentDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
int doc = ++currentDoc;
|
||||
if (doc == 0) {
|
||||
return doc;
|
||||
} else {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
if (target == 0) {
|
||||
currentDoc = target;
|
||||
return target;
|
||||
} else {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
public MemoryByteVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,6 +63,7 @@ import org.apache.lucene.index.IndexOptions;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexableField;
|
||||
import org.apache.lucene.index.IndexableFieldType;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.NumericDocValues;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
|
@ -851,9 +852,10 @@ public class TestMemoryIndex extends LuceneTestCase {
|
|||
.reader()
|
||||
.getFloatVectorValues(fieldName);
|
||||
assertNotNull(fvv);
|
||||
assertEquals(0, fvv.nextDoc());
|
||||
assertArrayEquals(expected, fvv.vectorValue(), 1e-6f);
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, fvv.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = fvv.iterator();
|
||||
assertEquals(0, iterator.nextDoc());
|
||||
assertArrayEquals(expected, fvv.vectorValue(0), 1e-6f);
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
|
||||
}
|
||||
|
||||
private static void assertFloatVectorScore(
|
||||
|
@ -868,7 +870,7 @@ public class TestMemoryIndex extends LuceneTestCase {
|
|||
.getFloatVectorValues(fieldName);
|
||||
assertNotNull(fvv);
|
||||
if (random().nextBoolean()) {
|
||||
fvv.nextDoc();
|
||||
fvv.iterator().nextDoc();
|
||||
}
|
||||
VectorScorer scorer = fvv.scorer(queryVector);
|
||||
assertEquals(0, scorer.iterator().nextDoc());
|
||||
|
@ -886,9 +888,10 @@ public class TestMemoryIndex extends LuceneTestCase {
|
|||
.reader()
|
||||
.getByteVectorValues(fieldName);
|
||||
assertNotNull(bvv);
|
||||
assertEquals(0, bvv.nextDoc());
|
||||
assertArrayEquals(expected, bvv.vectorValue());
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, bvv.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = bvv.iterator();
|
||||
assertEquals(0, iterator.nextDoc());
|
||||
assertArrayEquals(expected, bvv.vectorValue(0));
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc());
|
||||
}
|
||||
|
||||
private static void assertByteVectorScore(
|
||||
|
@ -903,7 +906,7 @@ public class TestMemoryIndex extends LuceneTestCase {
|
|||
.getByteVectorValues(fieldName);
|
||||
assertNotNull(bvv);
|
||||
if (random().nextBoolean()) {
|
||||
bvv.nextDoc();
|
||||
bvv.iterator().nextDoc();
|
||||
}
|
||||
VectorScorer scorer = bvv.scorer(queryVector);
|
||||
assertEquals(0, scorer.iterator().nextDoc());
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
|
@ -63,11 +64,12 @@ public class ByteKnnVectorFieldSource extends ValueSource {
|
|||
}
|
||||
|
||||
return new VectorFieldFunction(this) {
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
|
||||
@Override
|
||||
public byte[] byteVectorVal(int doc) throws IOException {
|
||||
if (exists(doc)) {
|
||||
return vectorValues.vectorValue();
|
||||
return vectorValues.vectorValue(iterator.index());
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
@ -75,7 +77,7 @@ public class ByteKnnVectorFieldSource extends ValueSource {
|
|||
|
||||
@Override
|
||||
protected DocIdSetIterator getVectorIterator() {
|
||||
return vectorValues;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
|||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
|
@ -62,11 +63,12 @@ public class FloatKnnVectorFieldSource extends ValueSource {
|
|||
}
|
||||
|
||||
return new VectorFieldFunction(this) {
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
|
||||
@Override
|
||||
public float[] floatVectorVal(int doc) throws IOException {
|
||||
if (exists(doc)) {
|
||||
return vectorValues.vectorValue();
|
||||
return vectorValues.vectorValue(iterator.index());
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
|
@ -74,7 +76,7 @@ public class FloatKnnVectorFieldSource extends ValueSource {
|
|||
|
||||
@Override
|
||||
protected DocIdSetIterator getVectorIterator() {
|
||||
return vectorValues;
|
||||
return iterator;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -25,11 +25,11 @@ import java.util.HashSet;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/** KMeans clustering algorithm for vectors */
|
||||
public class KMeans {
|
||||
|
@ -38,7 +38,7 @@ public class KMeans {
|
|||
public static final int DEFAULT_ITRS = 10;
|
||||
public static final int DEFAULT_SAMPLE_SIZE = 100_000;
|
||||
|
||||
private final RandomAccessVectorValues.Floats vectors;
|
||||
private final FloatVectorValues vectors;
|
||||
private final int numVectors;
|
||||
private final int numCentroids;
|
||||
private final Random random;
|
||||
|
@ -57,9 +57,7 @@ public class KMeans {
|
|||
* @throws IOException when if there is an error accessing vectors
|
||||
*/
|
||||
public static Results cluster(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int numClusters)
|
||||
FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters)
|
||||
throws IOException {
|
||||
return cluster(
|
||||
vectors,
|
||||
|
@ -93,7 +91,7 @@ public class KMeans {
|
|||
* @throws IOException if there is error accessing vectors
|
||||
*/
|
||||
public static Results cluster(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
FloatVectorValues vectors,
|
||||
int numClusters,
|
||||
boolean assignCentroidsToVectors,
|
||||
long seed,
|
||||
|
@ -124,7 +122,7 @@ public class KMeans {
|
|||
if (numClusters == 1) {
|
||||
centroids = new float[1][vectors.dimension()];
|
||||
} else {
|
||||
RandomAccessVectorValues.Floats sampleVectors =
|
||||
FloatVectorValues sampleVectors =
|
||||
vectors.size() <= sampleSize ? vectors : createSampleReader(vectors, sampleSize, seed);
|
||||
KMeans kmeans =
|
||||
new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters);
|
||||
|
@ -142,7 +140,7 @@ public class KMeans {
|
|||
}
|
||||
|
||||
private KMeans(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
FloatVectorValues vectors,
|
||||
int numCentroids,
|
||||
Random random,
|
||||
KmeansInitializationMethod initializationMethod,
|
||||
|
@ -276,7 +274,7 @@ public class KMeans {
|
|||
* @throws IOException if there is an error accessing vector values
|
||||
*/
|
||||
private static double runKMeansStep(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
FloatVectorValues vectors,
|
||||
float[][] centroids,
|
||||
short[] docCentroids,
|
||||
boolean useKahanSummation,
|
||||
|
@ -348,9 +346,7 @@ public class KMeans {
|
|||
* descending distance to the current centroid set
|
||||
*/
|
||||
static void assignCentroids(
|
||||
RandomAccessVectorValues.Floats vectors,
|
||||
float[][] centroids,
|
||||
List<Integer> unassignedCentroidsIdxs)
|
||||
FloatVectorValues vectors, float[][] centroids, List<Integer> unassignedCentroidsIdxs)
|
||||
throws IOException {
|
||||
int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()];
|
||||
int assignedIndex = 0;
|
||||
|
|
|
@ -20,18 +20,18 @@ package org.apache.lucene.sandbox.codecs.quantization;
|
|||
import java.io.IOException;
|
||||
import java.util.Random;
|
||||
import java.util.function.IntUnaryOperator;
|
||||
import org.apache.lucene.codecs.lucene95.HasIndexSlice;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/** A reader of vector values that samples a subset of the vectors. */
|
||||
public class SampleReader implements RandomAccessVectorValues.Floats {
|
||||
private final RandomAccessVectorValues.Floats origin;
|
||||
public class SampleReader extends FloatVectorValues implements HasIndexSlice {
|
||||
private final FloatVectorValues origin;
|
||||
private final int sampleSize;
|
||||
private final IntUnaryOperator sampleFunction;
|
||||
|
||||
SampleReader(
|
||||
RandomAccessVectorValues.Floats origin, int sampleSize, IntUnaryOperator sampleFunction) {
|
||||
SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) {
|
||||
this.origin = origin;
|
||||
this.sampleSize = sampleSize;
|
||||
this.sampleFunction = sampleFunction;
|
||||
|
@ -48,13 +48,13 @@ public class SampleReader implements RandomAccessVectorValues.Floats {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Floats copy() throws IOException {
|
||||
public FloatVectorValues copy() throws IOException {
|
||||
throw new IllegalStateException("Not supported");
|
||||
}
|
||||
|
||||
@Override
|
||||
public IndexInput getSlice() {
|
||||
return origin.getSlice();
|
||||
return ((HasIndexSlice) origin).getSlice();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -77,8 +77,7 @@ public class SampleReader implements RandomAccessVectorValues.Floats {
|
|||
throw new IllegalStateException("Not supported");
|
||||
}
|
||||
|
||||
public static SampleReader createSampleReader(
|
||||
RandomAccessVectorValues.Floats origin, int k, long seed) {
|
||||
public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) {
|
||||
int[] samples = reservoirSample(origin.size(), k, seed);
|
||||
return new SampleReader(origin, samples.length, i -> samples[i]);
|
||||
}
|
||||
|
|
|
@ -20,9 +20,9 @@ package org.apache.lucene.sandbox.codecs.quantization;
|
|||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
public class TestKMeans extends LuceneTestCase {
|
||||
|
||||
|
@ -32,7 +32,7 @@ public class TestKMeans extends LuceneTestCase {
|
|||
int dims = random().nextInt(2, 20);
|
||||
int randIdx = random().nextInt(VectorSimilarityFunction.values().length);
|
||||
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx];
|
||||
RandomAccessVectorValues.Floats vectors = generateData(nVectors, dims, nClusters);
|
||||
FloatVectorValues vectors = generateData(nVectors, dims, nClusters);
|
||||
|
||||
// default case
|
||||
{
|
||||
|
@ -75,7 +75,7 @@ public class TestKMeans extends LuceneTestCase {
|
|||
// nClusters > nVectors
|
||||
int nClusters = 20;
|
||||
int nVectors = 10;
|
||||
RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
|
||||
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
|
||||
KMeans.Results results =
|
||||
KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
|
||||
// assert that we get 1 centroid, as nClusters will be adjusted
|
||||
|
@ -87,7 +87,7 @@ public class TestKMeans extends LuceneTestCase {
|
|||
int sampleSize = 2;
|
||||
int nClusters = 2;
|
||||
int nVectors = 300;
|
||||
RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
|
||||
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
|
||||
KMeans.KmeansInitializationMethod initializationMethod =
|
||||
KMeans.KmeansInitializationMethod.PLUS_PLUS;
|
||||
KMeans.Results results =
|
||||
|
@ -108,7 +108,7 @@ public class TestKMeans extends LuceneTestCase {
|
|||
// test unassigned centroids
|
||||
int nClusters = 4;
|
||||
int nVectors = 400;
|
||||
RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters);
|
||||
FloatVectorValues vectors = generateData(nVectors, 5, nClusters);
|
||||
KMeans.Results results =
|
||||
KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters);
|
||||
float[][] centroids = results.centroids();
|
||||
|
@ -118,8 +118,7 @@ public class TestKMeans extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private static RandomAccessVectorValues.Floats generateData(
|
||||
int nSamples, int nDims, int nClusters) {
|
||||
private static FloatVectorValues generateData(int nSamples, int nDims, int nClusters) {
|
||||
List<float[]> vectors = new ArrayList<>(nSamples);
|
||||
float[][] centroids = new float[nClusters][nDims];
|
||||
// Generate random centroids
|
||||
|
@ -137,6 +136,6 @@ public class TestKMeans extends LuceneTestCase {
|
|||
}
|
||||
vectors.add(vector);
|
||||
}
|
||||
return RandomAccessVectorValues.fromFloats(vectors, nDims);
|
||||
return FloatVectorValues.fromFloats(vectors, nDims);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,7 +125,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
&& fi.getVectorEncoding() == VectorEncoding.FLOAT32;
|
||||
FloatVectorValues floatValues = delegate.getFloatVectorValues(field);
|
||||
assert floatValues != null;
|
||||
assert floatValues.docID() == -1;
|
||||
assert floatValues.iterator().docID() == -1;
|
||||
assert floatValues.size() >= 0;
|
||||
assert floatValues.dimension() > 0;
|
||||
return floatValues;
|
||||
|
@ -139,7 +139,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
&& fi.getVectorEncoding() == VectorEncoding.BYTE;
|
||||
ByteVectorValues values = delegate.getByteVectorValues(field);
|
||||
assert values != null;
|
||||
assert values.docID() == -1;
|
||||
assert values.iterator().docID() == -1;
|
||||
assert values.size() >= 0;
|
||||
assert values.dimension() > 0;
|
||||
return values;
|
||||
|
|
|
@ -55,6 +55,7 @@ import org.apache.lucene.index.IndexOptions;
|
|||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.KnnVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.MergePolicy;
|
||||
|
@ -437,9 +438,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
try (IndexReader reader = DirectoryReader.open(w2)) {
|
||||
LeafReader r = getOnlyLeafReader(reader);
|
||||
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
|
||||
assertEquals(0, vectorValues.nextDoc());
|
||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
assertEquals(0, iterator.nextDoc());
|
||||
assertEquals(0, vectorValues.vectorValue(0)[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -462,9 +464,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
try (IndexReader reader = DirectoryReader.open(w2)) {
|
||||
LeafReader r = getOnlyLeafReader(reader);
|
||||
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
|
||||
assertNotEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
assertNotEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
assertEquals(0, vectorValues.vectorValue(iterator.index())[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -489,12 +492,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
try (IndexReader reader = DirectoryReader.open(w2)) {
|
||||
LeafReader r = getOnlyLeafReader(reader);
|
||||
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
|
||||
assertEquals(0, vectorValues.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
assertEquals(0, iterator.nextDoc());
|
||||
// The merge order is randomized, we might get 0 first, or 1
|
||||
float value = vectorValues.vectorValue()[0];
|
||||
float value = vectorValues.vectorValue(0)[0];
|
||||
assertTrue(value == 0 || value == 1);
|
||||
assertEquals(1, vectorValues.nextDoc());
|
||||
value += vectorValues.vectorValue()[0];
|
||||
assertEquals(1, iterator.nextDoc());
|
||||
value += vectorValues.vectorValue(1)[0];
|
||||
assertEquals(1, value, 0);
|
||||
}
|
||||
}
|
||||
|
@ -879,8 +883,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||
if (byteVectorValues != null) {
|
||||
docCount += byteVectorValues.size();
|
||||
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += byteVectorValues.vectorValue()[0];
|
||||
KnnVectorValues.DocIndexIterator iterator = byteVectorValues.iterator();
|
||||
while (true) {
|
||||
if (!(iterator.nextDoc() != NO_MORE_DOCS)) break;
|
||||
checksum += byteVectorValues.vectorValue(iterator.index())[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -890,8 +896,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName);
|
||||
if (vectorValues != null) {
|
||||
docCount += vectorValues.size();
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += vectorValues.vectorValue()[0];
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
while (true) {
|
||||
if (!(iterator.nextDoc() != NO_MORE_DOCS)) break;
|
||||
checksum += vectorValues.vectorValue(iterator.index())[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -950,10 +958,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
assertSame(iterator, scorer.iterator());
|
||||
assertNotSame(iterator, scorer);
|
||||
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator();
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS) {
|
||||
if (!(valuesIterator.nextDoc() != NO_MORE_DOCS)) break;
|
||||
float score = scorer.score();
|
||||
assertTrue(score >= 0f);
|
||||
assertEquals(iterator.docID(), vectorValues.docID());
|
||||
assertEquals(iterator.docID(), valuesIterator.docID());
|
||||
}
|
||||
// verify that a new scorer can be obtained after iteration
|
||||
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
|
||||
|
@ -1009,10 +1019,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
assertSame(iterator, scorer.iterator());
|
||||
assertNotSame(iterator, scorer);
|
||||
// verify scorer iteration scores are valid & iteration with vectorValues is consistent
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator();
|
||||
while (iterator.nextDoc() != NO_MORE_DOCS) {
|
||||
if (!(valuesIterator.nextDoc() != NO_MORE_DOCS)) break;
|
||||
float score = scorer.score();
|
||||
assertTrue(score >= 0f);
|
||||
assertEquals(iterator.docID(), vectorValues.docID());
|
||||
assertEquals(iterator.docID(), valuesIterator.docID());
|
||||
}
|
||||
// verify that a new scorer can be obtained after iteration
|
||||
VectorScorer newScorer = vectorValues.scorer(vectorToScore);
|
||||
|
@ -1118,12 +1130,16 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
LeafReader r = getOnlyLeafReader(reader);
|
||||
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
|
||||
assertEquals(3, vectorValues.size());
|
||||
vectorValues.nextDoc();
|
||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
||||
vectorValues.nextDoc();
|
||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
||||
vectorValues.nextDoc();
|
||||
assertEquals(2, vectorValues.vectorValue()[0], 0);
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
iterator.nextDoc();
|
||||
assertEquals(0, iterator.index());
|
||||
assertEquals(1, vectorValues.vectorValue(0)[0], 0);
|
||||
iterator.nextDoc();
|
||||
assertEquals(1, iterator.index());
|
||||
assertEquals(1, vectorValues.vectorValue(1)[0], 0);
|
||||
iterator.nextDoc();
|
||||
assertEquals(2, iterator.index());
|
||||
assertEquals(2, vectorValues.vectorValue(2)[0], 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1146,13 +1162,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
FloatVectorValues vectorValues = leaf.getFloatVectorValues(fieldName);
|
||||
assertEquals(2, vectorValues.dimension());
|
||||
assertEquals(3, vectorValues.size());
|
||||
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(-1f, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
assertEquals("1", storedFields.document(iterator.nextDoc()).get("id"));
|
||||
assertEquals(-1f, vectorValues.vectorValue(0)[0], 0);
|
||||
assertEquals("2", storedFields.document(iterator.nextDoc()).get("id"));
|
||||
assertEquals(1, vectorValues.vectorValue(1)[0], 0);
|
||||
assertEquals("4", storedFields.document(iterator.nextDoc()).get("id"));
|
||||
assertEquals(0, vectorValues.vectorValue(2)[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1175,13 +1192,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName);
|
||||
assertEquals(2, vectorValues.dimension());
|
||||
assertEquals(3, vectorValues.size());
|
||||
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(-1, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
assertEquals("1", storedFields.document(vectorValues.iterator().nextDoc()).get("id"));
|
||||
assertEquals(-1, vectorValues.vectorValue(0)[0], 0);
|
||||
assertEquals("2", storedFields.document(vectorValues.iterator().nextDoc()).get("id"));
|
||||
assertEquals(1, vectorValues.vectorValue(1)[0], 0);
|
||||
assertEquals("4", storedFields.document(vectorValues.iterator().nextDoc()).get("id"));
|
||||
assertEquals(0, vectorValues.vectorValue(2)[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1211,27 +1228,30 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1");
|
||||
assertEquals(2, vectorValues.dimension());
|
||||
assertEquals(2, vectorValues.size());
|
||||
vectorValues.nextDoc();
|
||||
assertEquals(1f, vectorValues.vectorValue()[0], 0);
|
||||
vectorValues.nextDoc();
|
||||
assertEquals(2f, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
iterator.nextDoc();
|
||||
assertEquals(1f, vectorValues.vectorValue(0)[0], 0);
|
||||
iterator.nextDoc();
|
||||
assertEquals(2f, vectorValues.vectorValue(1)[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, iterator.nextDoc());
|
||||
|
||||
FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2");
|
||||
KnnVectorValues.DocIndexIterator it2 = vectorValues2.iterator();
|
||||
assertEquals(4, vectorValues2.dimension());
|
||||
assertEquals(2, vectorValues2.size());
|
||||
vectorValues2.nextDoc();
|
||||
assertEquals(2f, vectorValues2.vectorValue()[1], 0);
|
||||
vectorValues2.nextDoc();
|
||||
assertEquals(2f, vectorValues2.vectorValue()[1], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc());
|
||||
it2.nextDoc();
|
||||
assertEquals(2f, vectorValues2.vectorValue(0)[1], 0);
|
||||
it2.nextDoc();
|
||||
assertEquals(2f, vectorValues2.vectorValue(1)[1], 0);
|
||||
assertEquals(NO_MORE_DOCS, it2.nextDoc());
|
||||
|
||||
FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3");
|
||||
assertEquals(4, vectorValues3.dimension());
|
||||
assertEquals(1, vectorValues3.size());
|
||||
vectorValues3.nextDoc();
|
||||
assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator it3 = vectorValues3.iterator();
|
||||
it3.nextDoc();
|
||||
assertEquals(1f, vectorValues3.vectorValue(0)[0], 0.1);
|
||||
assertEquals(NO_MORE_DOCS, it3.nextDoc());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1295,13 +1315,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
totalSize += vectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
int docId;
|
||||
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
float[] v = vectorValues.vectorValue();
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
while (true) {
|
||||
if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break;
|
||||
float[] v = vectorValues.vectorValue(iterator.index());
|
||||
assertEquals(dimension, v.length);
|
||||
String idString = storedFields.document(docId).getField("id").stringValue();
|
||||
int id = Integer.parseInt(idString);
|
||||
if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) {
|
||||
assertArrayEquals(idString, values[id], v, 0);
|
||||
assertArrayEquals(idString + " " + docId, values[id], v, 0);
|
||||
++valueCount;
|
||||
} else {
|
||||
++numDeletes;
|
||||
|
@ -1375,8 +1397,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
totalSize += vectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
int docId;
|
||||
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
byte[] v = vectorValues.vectorValue();
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
while (true) {
|
||||
if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break;
|
||||
byte[] v = vectorValues.vectorValue(iterator.index());
|
||||
assertEquals(dimension, v.length);
|
||||
String idString = storedFields.document(docId).getField("id").stringValue();
|
||||
int id = Integer.parseInt(idString);
|
||||
|
@ -1495,8 +1519,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
int docId;
|
||||
int numLiveDocsWithVectors = 0;
|
||||
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
float[] v = vectorValues.vectorValue();
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
while (true) {
|
||||
if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break;
|
||||
float[] v = vectorValues.vectorValue(iterator.index());
|
||||
assertEquals(dimension, v.length);
|
||||
String idString = storedFields.document(docId).getField("id").stringValue();
|
||||
int id = Integer.parseInt(idString);
|
||||
|
@ -1703,25 +1729,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName);
|
||||
int[] vectorDocs = new int[vectorValues.size() + 1];
|
||||
int cur = -1;
|
||||
KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator();
|
||||
while (++cur < vectorValues.size() + 1) {
|
||||
vectorDocs[cur] = vectorValues.nextDoc();
|
||||
vectorDocs[cur] = iterator.nextDoc();
|
||||
if (cur != 0) {
|
||||
assertTrue(vectorDocs[cur] > vectorDocs[cur - 1]);
|
||||
}
|
||||
}
|
||||
vectorValues = r.getFloatVectorValues(fieldName);
|
||||
DocIdSetIterator iter = vectorValues.iterator();
|
||||
cur = -1;
|
||||
for (int i = 0; i < numdocs; i++) {
|
||||
// randomly advance to i
|
||||
if (random().nextInt(4) == 3) {
|
||||
while (vectorDocs[++cur] < i) {}
|
||||
assertEquals(vectorDocs[cur], vectorValues.advance(i));
|
||||
assertEquals(vectorDocs[cur], vectorValues.docID());
|
||||
if (vectorValues.docID() == NO_MORE_DOCS) {
|
||||
assertEquals(vectorDocs[cur], iter.advance(i));
|
||||
assertEquals(vectorDocs[cur], iter.docID());
|
||||
if (iter.docID() == NO_MORE_DOCS) {
|
||||
break;
|
||||
}
|
||||
// make i equal to docid so that it is greater than docId in the next loop iteration
|
||||
i = vectorValues.docID();
|
||||
i = iter.docID();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1772,6 +1800,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
double checksum = 0;
|
||||
int docCount = 0;
|
||||
long sumDocIds = 0;
|
||||
long sumOrdToDocIds = 0;
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> {
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
|
@ -1779,11 +1808,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
if (byteVectorValues != null) {
|
||||
docCount += byteVectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += byteVectorValues.vectorValue()[0];
|
||||
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
|
||||
KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator();
|
||||
for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) {
|
||||
int ord = iter.index();
|
||||
checksum += byteVectorValues.vectorValue(ord)[0];
|
||||
Document doc = storedFields.document(iter.docID(), Set.of("id"));
|
||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
for (int ord = 0; ord < byteVectorValues.size(); ord++) {
|
||||
Document doc =
|
||||
storedFields.document(byteVectorValues.ordToDoc(ord), Set.of("id"));
|
||||
sumOrdToDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1793,11 +1829,17 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
if (vectorValues != null) {
|
||||
docCount += vectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += vectorValues.vectorValue()[0];
|
||||
Document doc = storedFields.document(vectorValues.docID(), Set.of("id"));
|
||||
KnnVectorValues.DocIndexIterator iter = vectorValues.iterator();
|
||||
for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) {
|
||||
int ord = iter.index();
|
||||
checksum += vectorValues.vectorValue(ord)[0];
|
||||
Document doc = storedFields.document(iter.docID(), Set.of("id"));
|
||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
for (int ord = 0; ord < vectorValues.size(); ord++) {
|
||||
Document doc = storedFields.document(vectorValues.ordToDoc(ord), Set.of("id"));
|
||||
sumOrdToDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1809,6 +1851,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
|
||||
assertEquals(fieldDocCount, docCount);
|
||||
assertEquals(fieldSumDocIDs, sumDocIds);
|
||||
assertEquals(fieldSumDocIDs, sumOrdToDocIds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1839,25 +1882,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
|
||||
ByteVectorValues byteVectors = leafReader.getByteVectorValues("byte");
|
||||
assertNotNull(byteVectors);
|
||||
assertEquals(0, byteVectors.nextDoc());
|
||||
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue());
|
||||
assertEquals(1, byteVectors.nextDoc());
|
||||
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue());
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, byteVectors.nextDoc());
|
||||
KnnVectorValues.DocIndexIterator iter = byteVectors.iterator();
|
||||
assertEquals(0, iter.nextDoc());
|
||||
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(0));
|
||||
assertEquals(1, iter.nextDoc());
|
||||
assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(1));
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc());
|
||||
|
||||
FloatVectorValues floatVectors = leafReader.getFloatVectorValues("float");
|
||||
assertNotNull(floatVectors);
|
||||
assertEquals(0, floatVectors.nextDoc());
|
||||
float[] vector = floatVectors.vectorValue();
|
||||
iter = floatVectors.iterator();
|
||||
assertEquals(0, iter.nextDoc());
|
||||
float[] vector = floatVectors.vectorValue(0);
|
||||
assertEquals(2, vector.length);
|
||||
assertEquals(1f, vector[0], 0f);
|
||||
assertEquals(2f, vector[1], 0f);
|
||||
assertEquals(1, floatVectors.nextDoc());
|
||||
vector = floatVectors.vectorValue();
|
||||
assertEquals(1, iter.nextDoc());
|
||||
vector = floatVectors.vectorValue(1);
|
||||
assertEquals(2, vector.length);
|
||||
assertEquals(1f, vector[0], 0f);
|
||||
assertEquals(2f, vector[1], 0f);
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, floatVectors.nextDoc());
|
||||
assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc());
|
||||
|
||||
IOUtils.close(reader, w2, dir1, dir2);
|
||||
}
|
||||
|
|
|
@ -183,7 +183,7 @@ public class AssertingScorer extends Scorer {
|
|||
} else {
|
||||
state = IteratorState.ITERATING;
|
||||
}
|
||||
assert in.docID() == advanced;
|
||||
assert in.docID() == advanced : in.docID() + " != " + advanced + " in " + in;
|
||||
assert AssertingScorer.this.in.docID() == in.docID();
|
||||
return doc = advanced;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue