mirror of https://github.com/apache/lucene.git
Expose FlatVectorsFormat (#13469)
This commit is contained in:
parent
048770205c
commit
487d24ae69
|
@ -52,11 +52,6 @@ class Lucene99RWHnswScalarQuantizationVectorsFormat
|
|||
null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return 1024;
|
||||
}
|
||||
|
||||
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
|
||||
private static final FlatVectorsFormat rawVectorFormat =
|
||||
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
|
||||
|
|
|
@ -76,7 +76,8 @@ module org.apache.lucene.core {
|
|||
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||
provides org.apache.lucene.codecs.KnnVectorsFormat with
|
||||
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
|
||||
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
|
||||
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
|
||||
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
|
||||
provides org.apache.lucene.codecs.PostingsFormat with
|
||||
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
|
||||
provides org.apache.lucene.index.SortFieldProvider with
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
|
@ -27,14 +28,23 @@ import org.apache.lucene.index.SegmentWriteState;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class FlatVectorsFormat {
|
||||
public abstract class FlatVectorsFormat extends KnnVectorsFormat {
|
||||
|
||||
/** Sole constructor */
|
||||
protected FlatVectorsFormat() {}
|
||||
protected FlatVectorsFormat(String name) {
|
||||
super(name);
|
||||
}
|
||||
|
||||
/** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */
|
||||
@Override
|
||||
public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException;
|
||||
|
||||
/** Returns a {@link FlatVectorsReader} to read the vectors from the index. */
|
||||
@Override
|
||||
public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException;
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return 1024;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,12 +17,11 @@
|
|||
|
||||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
/**
|
||||
|
@ -39,7 +38,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class FlatVectorsReader implements Closeable, Accountable {
|
||||
public abstract class FlatVectorsReader extends KnnVectorsReader implements Accountable {
|
||||
|
||||
/** Scorer for flat vectors */
|
||||
protected final FlatVectorsScorer vectorScorer;
|
||||
|
@ -56,6 +55,18 @@ public abstract class FlatVectorsReader implements Closeable, Accountable {
|
|||
return vectorScorer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
// don't scan stored field data. If we didn't index it, produce no search results
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
// don't scan stored field data. If we didn't index it, produce no search results
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a {@link RandomVectorScorer} for the given field and target vector.
|
||||
*
|
||||
|
@ -77,28 +88,4 @@ public abstract class FlatVectorsReader implements Closeable, Accountable {
|
|||
*/
|
||||
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
|
||||
throws IOException;
|
||||
|
||||
/**
|
||||
* Checks consistency of this reader.
|
||||
*
|
||||
* <p>Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value
|
||||
* against large data files.
|
||||
*
|
||||
* @lucene.internal
|
||||
*/
|
||||
public abstract void checkIntegrity() throws IOException;
|
||||
|
||||
/**
|
||||
* Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if
|
||||
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
|
||||
* never {@code null}.
|
||||
*/
|
||||
public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
|
||||
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
|
||||
* never {@code null}.
|
||||
*/
|
||||
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
|
||||
}
|
||||
|
|
|
@ -17,14 +17,11 @@
|
|||
|
||||
package org.apache.lucene.codecs.hnsw;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
||||
|
||||
/**
|
||||
|
@ -32,7 +29,7 @@ import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class FlatVectorsWriter implements Accountable, Closeable {
|
||||
public abstract class FlatVectorsWriter extends KnnVectorsWriter {
|
||||
/** Scorer for flat vectors */
|
||||
protected final FlatVectorsScorer vectorsScorer;
|
||||
|
||||
|
@ -60,6 +57,11 @@ public abstract class FlatVectorsWriter implements Accountable, Closeable {
|
|||
public abstract FlatFieldVectorsWriter<?> addField(
|
||||
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;
|
||||
|
||||
@Override
|
||||
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
return addField(fieldInfo, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
|
||||
* any additional merging logic can be implemented by the user of this class.
|
||||
|
@ -72,15 +74,4 @@ public abstract class FlatVectorsWriter implements Accountable, Closeable {
|
|||
*/
|
||||
public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
|
||||
FieldInfo fieldInfo, MergeState mergeState) throws IOException;
|
||||
|
||||
/** Write field for merging */
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState));
|
||||
}
|
||||
|
||||
/** Called once at the end before close */
|
||||
public abstract void finish() throws IOException;
|
||||
|
||||
/** Flush all buffered data to disk */
|
||||
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
|
||||
}
|
||||
|
|
|
@ -67,6 +67,7 @@ import org.apache.lucene.store.IndexOutput;
|
|||
*/
|
||||
public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
|
||||
|
||||
static final String NAME = "Lucene99FlatVectorsFormat";
|
||||
static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta";
|
||||
static final String VECTOR_DATA_CODEC_NAME = "Lucene99FlatVectorsFormatData";
|
||||
static final String META_EXTENSION = "vemf";
|
||||
|
@ -80,6 +81,7 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
|
|||
|
||||
/**
 * Constructs a format, passing {@code NAME} to the superclass constructor.
 *
 * @param vectorsScorer the scorer stored on this format
 */
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
  super(NAME);
  this.vectorsScorer = vectorsScorer;
}
|
||||
|
||||
|
|
|
@ -119,6 +119,11 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
|
|||
return newField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
return addField(fieldInfo, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (FieldWriter<?> field : fields) {
|
||||
|
|
|
@ -89,6 +89,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
|||
*/
|
||||
public Lucene99ScalarQuantizedVectorsFormat(
|
||||
Float confidenceInterval, int bits, boolean compress) {
|
||||
super(NAME);
|
||||
if (confidenceInterval != null
|
||||
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
|
||||
&& (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL
|
||||
|
|
|
@ -45,11 +45,14 @@ import java.util.function.Supplier;
|
|||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.DocValuesProducer;
|
||||
import org.apache.lucene.codecs.FieldsProducer;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.NormsProducer;
|
||||
import org.apache.lucene.codecs.PointsReader;
|
||||
import org.apache.lucene.codecs.PostingsFormat;
|
||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.DocumentStoredFieldVisitor;
|
||||
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
|
||||
|
@ -2742,6 +2745,14 @@ public final class CheckIndex implements Closeable {
|
|||
return status;
|
||||
}
|
||||
|
||||
private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
|
||||
KnnVectorsReader vectorsReader = codecReader.getVectorReader();
|
||||
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
|
||||
vectorsReader = perFieldReader.getFieldReader(fieldName);
|
||||
}
|
||||
return (vectorsReader instanceof FlatVectorsReader) == false;
|
||||
}
|
||||
|
||||
private static void checkFloatVectorValues(
|
||||
FloatVectorValues values,
|
||||
FieldInfo fieldInfo,
|
||||
|
@ -2754,11 +2765,15 @@ public final class CheckIndex implements Closeable {
|
|||
// search a sample of the vectors (docIDs that are a multiple of everyNdoc) to exercise the graph
|
||||
if (values.docID() % everyNdoc == 0) {
|
||||
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
|
||||
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
|
||||
TopDocs docs = collector.topDocs();
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||
if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) {
|
||||
codecReader
|
||||
.getVectorReader()
|
||||
.search(fieldInfo.name, values.vectorValue(), collector, null);
|
||||
TopDocs docs = collector.topDocs();
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||
}
|
||||
}
|
||||
}
|
||||
int valueLength = values.vectorValue().length;
|
||||
|
@ -2794,9 +2809,10 @@ public final class CheckIndex implements Closeable {
|
|||
throws IOException {
|
||||
int docCount = 0;
|
||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||
boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name);
|
||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||
// search a sample of the vectors (docIDs that are a multiple of everyNdoc) to exercise the graph
|
||||
if (values.docID() % everyNdoc == 0) {
|
||||
if (supportsSearch && values.docID() % everyNdoc == 0) {
|
||||
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
|
||||
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
|
||||
TopDocs docs = collector.topDocs();
|
||||
|
|
|
@ -15,3 +15,4 @@
|
|||
|
||||
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
|
||||
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
|
||||
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat
|
||||
|
|
|
@ -0,0 +1,239 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static java.lang.String.format;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.hamcrest.Matchers.oneOf;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.NoMergePolicy;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
import org.junit.Before;
|
||||
|
||||
public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
||||
|
||||
KnnVectorsFormat format;
|
||||
Float confidenceInterval;
|
||||
int bits;
|
||||
|
||||
@Before
|
||||
@Override
|
||||
public void setUp() throws Exception {
|
||||
bits = random().nextBoolean() ? 4 : 7;
|
||||
confidenceInterval = random().nextBoolean() ? random().nextFloat(0.90f, 1.0f) : null;
|
||||
if (random().nextBoolean()) {
|
||||
confidenceInterval = 0f;
|
||||
}
|
||||
format =
|
||||
new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, random().nextBoolean());
|
||||
super.setUp();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Codec getCodec() {
|
||||
return new Lucene99Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return format;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public void testSearch() throws Exception {
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
Document doc = new Document();
|
||||
// randomly reuse a vector, this ensures the underlying codec doesn't rely on the array
|
||||
// reference
|
||||
doc.add(
|
||||
new KnnFloatVectorField("f", new float[] {0, 1}, VectorSimilarityFunction.DOT_PRODUCT));
|
||||
w.addDocument(doc);
|
||||
w.commit();
|
||||
try (IndexReader reader = DirectoryReader.open(w)) {
|
||||
LeafReader r = getOnlyLeafReader(reader);
|
||||
if (r instanceof CodecReader codecReader) {
|
||||
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
|
||||
// if this search found any results it would raise NPE attempting to collect them in our
|
||||
// null collector
|
||||
knnVectorsReader.search("f", new float[] {1, 0}, null, null);
|
||||
} else {
|
||||
fail("reader is not CodecReader");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testQuantizedVectorsWriteAndRead() throws Exception {
|
||||
// create lucene directory with codec
|
||||
int numVectors = 1 + random().nextInt(50);
|
||||
VectorSimilarityFunction similarityFunction = randomSimilarity();
|
||||
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
|
||||
int dim = random().nextInt(64) + 1;
|
||||
if (dim % 2 == 1) {
|
||||
dim++;
|
||||
}
|
||||
List<float[]> vectors = new ArrayList<>(numVectors);
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
vectors.add(randomVector(dim));
|
||||
}
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
confidenceInterval != null && confidenceInterval == 0f
|
||||
? ScalarQuantizer.fromVectorsAutoInterval(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||
similarityFunction,
|
||||
numVectors,
|
||||
(byte) bits)
|
||||
: ScalarQuantizer.fromVectors(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||
confidenceInterval == null
|
||||
? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim)
|
||||
: confidenceInterval,
|
||||
numVectors,
|
||||
(byte) bits);
|
||||
float[] expectedCorrections = new float[numVectors];
|
||||
byte[][] expectedVectors = new byte[numVectors][];
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
float[] vector = vectors.get(i);
|
||||
if (normalize) {
|
||||
float[] copy = new float[vector.length];
|
||||
System.arraycopy(vector, 0, copy, 0, copy.length);
|
||||
VectorUtil.l2normalize(copy);
|
||||
vector = copy;
|
||||
}
|
||||
|
||||
expectedVectors[i] = new byte[dim];
|
||||
expectedCorrections[i] =
|
||||
scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction);
|
||||
}
|
||||
float[] randomlyReusedVector = new float[dim];
|
||||
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w =
|
||||
new IndexWriter(
|
||||
dir,
|
||||
new IndexWriterConfig()
|
||||
.setMaxBufferedDocs(numVectors + 1)
|
||||
.setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH)
|
||||
.setMergePolicy(NoMergePolicy.INSTANCE))) {
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
Document doc = new Document();
|
||||
// randomly reuse a vector, this ensures the underlying codec doesn't rely on the array
|
||||
// reference
|
||||
final float[] v;
|
||||
if (random().nextBoolean()) {
|
||||
System.arraycopy(vectors.get(i), 0, randomlyReusedVector, 0, dim);
|
||||
v = randomlyReusedVector;
|
||||
} else {
|
||||
v = vectors.get(i);
|
||||
}
|
||||
doc.add(new KnnFloatVectorField("f", v, similarityFunction));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
w.commit();
|
||||
try (IndexReader reader = DirectoryReader.open(w)) {
|
||||
LeafReader r = getOnlyLeafReader(reader);
|
||||
if (r instanceof CodecReader codecReader) {
|
||||
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
|
||||
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
|
||||
knnVectorsReader = fieldsReader.getFieldReader("f");
|
||||
}
|
||||
if (knnVectorsReader instanceof Lucene99ScalarQuantizedVectorsReader quantizedReader) {
|
||||
assertNotNull(quantizedReader.getQuantizationState("f"));
|
||||
QuantizedByteVectorValues quantizedByteVectorValues =
|
||||
quantizedReader.getQuantizedVectorValues("f");
|
||||
int docId = -1;
|
||||
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
byte[] vector = quantizedByteVectorValues.vectorValue();
|
||||
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
|
||||
for (int i = 0; i < dim; i++) {
|
||||
assertEquals(vector[i], expectedVectors[docId][i]);
|
||||
}
|
||||
assertEquals(offset, expectedCorrections[docId], 0.00001f);
|
||||
}
|
||||
} else {
|
||||
fail("reader is not Lucene99ScalarQuantizedVectorsReader");
|
||||
}
|
||||
} else {
|
||||
fail("reader is not CodecReader");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testToString() {
|
||||
FilterCodec customCodec =
|
||||
new FilterCodec("foo", Codec.getDefault()) {
|
||||
@Override
|
||||
public KnnVectorsFormat knnVectorsFormat() {
|
||||
return new Lucene99ScalarQuantizedVectorsFormat(0.9f, (byte) 4, false);
|
||||
}
|
||||
};
|
||||
String expectedPattern =
|
||||
"Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
|
||||
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
|
||||
var memSegScorer =
|
||||
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
|
||||
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
|
||||
}
|
||||
|
||||
public void testLimits() {
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Lucene99ScalarQuantizedVectorsFormat(1.1f, 7, false));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Lucene99ScalarQuantizedVectorsFormat(null, -1, false));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 5, false));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 9, false));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void testRandomWithUpdatesAndGraph() {
|
||||
// graph not supported
|
||||
}
|
||||
|
||||
@Override
|
||||
public void testSearchWithVisitedLimit() {
|
||||
// search not supported
|
||||
}
|
||||
}
|
|
@ -102,6 +102,7 @@ import junit.framework.AssertionFailedError;
|
|||
import org.apache.lucene.analysis.Analyzer;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat;
|
||||
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.Field.Store;
|
||||
|
@ -3223,11 +3224,16 @@ public abstract class LuceneTestCase extends Assert {
|
|||
return true;
|
||||
}
|
||||
|
||||
private static boolean supportsVectorSearch(KnnVectorsFormat format) {
|
||||
return (format instanceof FlatVectorsFormat) == false;
|
||||
}
|
||||
|
||||
protected static KnnVectorsFormat randomVectorFormat(VectorEncoding vectorEncoding) {
|
||||
List<KnnVectorsFormat> availableFormats =
|
||||
KnnVectorsFormat.availableKnnVectorsFormats().stream()
|
||||
.map(KnnVectorsFormat::forName)
|
||||
.filter(format -> supportsVectorEncoding(format, vectorEncoding))
|
||||
.filter(format -> supportsVectorSearch(format))
|
||||
.toList();
|
||||
return RandomPicks.randomFrom(random(), availableFormats);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue