Expose FlatVectorsFormat (#13469)

This commit is contained in:
Michael Sokolov 2024-06-13 19:38:24 -04:00 committed by GitHub
parent 048770205c
commit 487d24ae69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 313 additions and 59 deletions

View File

@ -52,11 +52,6 @@ class Lucene99RWHnswScalarQuantizationVectorsFormat
null);
}
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());

View File

@ -76,7 +76,8 @@ module org.apache.lucene.core {
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
provides org.apache.lucene.codecs.KnnVectorsFormat with
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat,
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
provides org.apache.lucene.codecs.PostingsFormat with
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
provides org.apache.lucene.index.SortFieldProvider with

View File

@ -18,6 +18,7 @@
package org.apache.lucene.codecs.hnsw;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
@ -27,14 +28,23 @@ import org.apache.lucene.index.SegmentWriteState;
*
* @lucene.experimental
*/
public abstract class FlatVectorsFormat {
public abstract class FlatVectorsFormat extends KnnVectorsFormat {
/** Sole constructor */
protected FlatVectorsFormat() {}
protected FlatVectorsFormat(String name) {
super(name);
}
/** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */
@Override
public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException;
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
@Override
public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException;
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
}

View File

@ -17,12 +17,11 @@
package org.apache.lucene.codecs.hnsw;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
@ -39,7 +38,7 @@ import org.apache.lucene.util.hnsw.RandomVectorScorer;
*
* @lucene.experimental
*/
public abstract class FlatVectorsReader implements Closeable, Accountable {
public abstract class FlatVectorsReader extends KnnVectorsReader implements Accountable {
/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorScorer;
@ -56,6 +55,18 @@ public abstract class FlatVectorsReader implements Closeable, Accountable {
return vectorScorer;
}
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
// don't scan stored field data. If we didn't index it, produce no search results
}
/**
* Returns a {@link RandomVectorScorer} for the given field and target vector.
*
@ -77,28 +88,4 @@ public abstract class FlatVectorsReader implements Closeable, Accountable {
*/
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
throws IOException;
/**
* Checks consistency of this reader.
*
* <p>Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value
* against large data files.
*
* @lucene.internal
*/
public abstract void checkIntegrity() throws IOException;
/**
* Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException;
/**
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
}

View File

@ -17,14 +17,11 @@
package org.apache.lucene.codecs.hnsw;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
/**
@ -32,7 +29,7 @@ import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
*
* @lucene.experimental
*/
public abstract class FlatVectorsWriter implements Accountable, Closeable {
public abstract class FlatVectorsWriter extends KnnVectorsWriter {
/** Scorer for flat vectors */
protected final FlatVectorsScorer vectorsScorer;
@ -60,6 +57,11 @@ public abstract class FlatVectorsWriter implements Accountable, Closeable {
public abstract FlatFieldVectorsWriter<?> addField(
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;
@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}
/**
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
* any additional merging logic can be implemented by the user of this class.
@ -72,15 +74,4 @@ public abstract class FlatVectorsWriter implements Accountable, Closeable {
*/
public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
FieldInfo fieldInfo, MergeState mergeState) throws IOException;
/** Write field for merging */
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState));
}
/** Called once at the end before close */
public abstract void finish() throws IOException;
/** Flush all buffered data on disk * */
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
}

View File

@ -67,6 +67,7 @@ import org.apache.lucene.store.IndexOutput;
*/
public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
static final String NAME = "Lucene99FlatVectorsFormat";
static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene99FlatVectorsFormatData";
static final String META_EXTENSION = "vemf";
@ -80,6 +81,7 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
/** Constructs a format */
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
super(NAME);
this.vectorsScorer = vectorsScorer;
}

View File

@ -119,6 +119,11 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
return newField;
}
@Override
public FlatFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
return addField(fieldInfo, null);
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {

View File

@ -89,6 +89,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
*/
public Lucene99ScalarQuantizedVectorsFormat(
Float confidenceInterval, int bits, boolean compress) {
super(NAME);
if (confidenceInterval != null
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
&& (confidenceInterval < MINIMUM_CONFIDENCE_INTERVAL

View File

@ -45,11 +45,14 @@ import java.util.function.Supplier;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsReader;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.DocumentStoredFieldVisitor;
import org.apache.lucene.index.CheckIndex.Status.DocValuesStatus;
@ -2742,6 +2745,14 @@ public final class CheckIndex implements Closeable {
return status;
}
private static boolean vectorsReaderSupportsSearch(CodecReader codecReader, String fieldName) {
KnnVectorsReader vectorsReader = codecReader.getVectorReader();
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader perFieldReader) {
vectorsReader = perFieldReader.getFieldReader(fieldName);
}
return (vectorsReader instanceof FlatVectorsReader) == false;
}
private static void checkFloatVectorValues(
FloatVectorValues values,
FieldInfo fieldInfo,
@ -2754,11 +2765,15 @@ public final class CheckIndex implements Closeable {
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) {
codecReader
.getVectorReader()
.search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
}
int valueLength = values.vectorValue().length;
@ -2794,9 +2809,10 @@ public final class CheckIndex implements Closeable {
throws IOException {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name);
while (values.nextDoc() != NO_MORE_DOCS) {
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
if (supportsSearch && values.docID() % everyNdoc == 0) {
KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE);
codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null);
TopDocs docs = collector.topDocs();

View File

@ -15,3 +15,4 @@
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat

View File

@ -0,0 +1,239 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static java.lang.String.format;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.QuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;
import org.junit.Before;
public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
KnnVectorsFormat format;
Float confidenceInterval;
int bits;
@Before
@Override
public void setUp() throws Exception {
bits = random().nextBoolean() ? 4 : 7;
confidenceInterval = random().nextBoolean() ? random().nextFloat(0.90f, 1.0f) : null;
if (random().nextBoolean()) {
confidenceInterval = 0f;
}
format =
new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, random().nextBoolean());
super.setUp();
}
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return format;
}
};
}
public void testSearch() throws Exception {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
// randomly reuse a vector, this ensures the underlying codec doesn't rely on the array
// reference
doc.add(
new KnnFloatVectorField("f", new float[] {0, 1}, VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc);
w.commit();
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
if (r instanceof CodecReader codecReader) {
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
// if this search found any results it would raise NPE attempting to collect them in our
// null collector
knnVectorsReader.search("f", new float[] {1, 0}, null, null);
} else {
fail("reader is not CodecReader");
}
}
}
}
public void testQuantizedVectorsWriteAndRead() throws Exception {
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
VectorSimilarityFunction similarityFunction = randomSimilarity();
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
int dim = random().nextInt(64) + 1;
if (dim % 2 == 1) {
dim++;
}
List<float[]> vectors = new ArrayList<>(numVectors);
for (int i = 0; i < numVectors; i++) {
vectors.add(randomVector(dim));
}
ScalarQuantizer scalarQuantizer =
confidenceInterval != null && confidenceInterval == 0f
? ScalarQuantizer.fromVectorsAutoInterval(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
similarityFunction,
numVectors,
(byte) bits)
: ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
confidenceInterval == null
? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim)
: confidenceInterval,
numVectors,
(byte) bits);
float[] expectedCorrections = new float[numVectors];
byte[][] expectedVectors = new byte[numVectors][];
for (int i = 0; i < numVectors; i++) {
float[] vector = vectors.get(i);
if (normalize) {
float[] copy = new float[vector.length];
System.arraycopy(vector, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
vector = copy;
}
expectedVectors[i] = new byte[dim];
expectedCorrections[i] =
scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction);
}
float[] randomlyReusedVector = new float[dim];
try (Directory dir = newDirectory();
IndexWriter w =
new IndexWriter(
dir,
new IndexWriterConfig()
.setMaxBufferedDocs(numVectors + 1)
.setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH)
.setMergePolicy(NoMergePolicy.INSTANCE))) {
for (int i = 0; i < numVectors; i++) {
Document doc = new Document();
// randomly reuse a vector, this ensures the underlying codec doesn't rely on the array
// reference
final float[] v;
if (random().nextBoolean()) {
System.arraycopy(vectors.get(i), 0, randomlyReusedVector, 0, dim);
v = randomlyReusedVector;
} else {
v = vectors.get(i);
}
doc.add(new KnnFloatVectorField("f", v, similarityFunction));
w.addDocument(doc);
}
w.commit();
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
if (r instanceof CodecReader codecReader) {
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
knnVectorsReader = fieldsReader.getFieldReader("f");
}
if (knnVectorsReader instanceof Lucene99ScalarQuantizedVectorsReader quantizedReader) {
assertNotNull(quantizedReader.getQuantizationState("f"));
QuantizedByteVectorValues quantizedByteVectorValues =
quantizedReader.getQuantizedVectorValues("f");
int docId = -1;
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] vector = quantizedByteVectorValues.vectorValue();
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
for (int i = 0; i < dim; i++) {
assertEquals(vector[i], expectedVectors[docId][i]);
}
assertEquals(offset, expectedCorrections[docId], 0.00001f);
}
} else {
fail("reader is not Lucene99ScalarQuantizedVectorsReader");
}
} else {
fail("reader is not CodecReader");
}
}
}
}
public void testToString() {
FilterCodec customCodec =
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new Lucene99ScalarQuantizedVectorsFormat(0.9f, (byte) 4, false);
}
};
String expectedPattern =
"Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, confidenceInterval=0.9, bits=4, compress=false, flatVectorScorer=ScalarQuantizedVectorScorer(nonQuantizedDelegate=DefaultFlatVectorScorer()), rawVectorFormat=Lucene99FlatVectorsFormat(vectorsScorer=%s()))";
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
var memSegScorer =
format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer)));
}
public void testLimits() {
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(1.1f, 7, false));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(null, -1, false));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 5, false));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 9, false));
}
@Override
public void testRandomWithUpdatesAndGraph() {
// graph not supported
}
@Override
public void testSearchWithVisitedLimit() {
// search not supported
}
}

View File

@ -102,6 +102,7 @@ import junit.framework.AssertionFailedError;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.bitvectors.HnswBitVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.Field.Store;
@ -3223,11 +3224,16 @@ public abstract class LuceneTestCase extends Assert {
return true;
}
private static boolean supportsVectorSearch(KnnVectorsFormat format) {
return (format instanceof FlatVectorsFormat) == false;
}
protected static KnnVectorsFormat randomVectorFormat(VectorEncoding vectorEncoding) {
List<KnnVectorsFormat> availableFormats =
KnnVectorsFormat.availableKnnVectorsFormats().stream()
.map(KnnVectorsFormat::forName)
.filter(format -> supportsVectorEncoding(format, vectorEncoding))
.filter(format -> supportsVectorSearch(format))
.toList();
return RandomPicks.randomFrom(random(), availableFormats);
}