Move max vector dims limit to Codec (#12436)

Move enforcement of the maximum vector dimension limit into the default
Codec's KnnVectorsFormat implementation. This allows different
implementations of knn search algorithms to define their own limits on
the maximum number of vector dimensions that they can handle.

Closes #12309
This commit is contained in:
Mayya Sharipova 2023-07-27 14:50:33 -04:00 committed by GitHub
parent 538b7d0ffe
commit 98320d7616
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 164 additions and 44 deletions

View File

@ -129,6 +129,8 @@ API Changes
* GITHUB#11248: IntBlockPool's SliceReader, SliceWriter, and all int slice functionality are moved out to MemoryIndex.
(Stefan Vodita)
* GITHUB#12436: Move max vector dims limit to Codec (Mayya Sharipova)
New Features
---------------------

View File

@ -32,6 +32,9 @@ import org.apache.lucene.util.NamedSPILoader;
*/
public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
/** The default maximum number of vector dimensions, used when a format does not override it */
public static final int DEFAULT_MAX_DIMENSIONS = 1024;
/**
* This static holder class prevents classloading deadlock by delaying init of doc values formats
* until needed.
@ -76,6 +79,19 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;
/**
* Returns the maximum number of vector dimensions supported by this format for the given field
* name.
*
* <p>This base implementation ignores {@code fieldName} and returns {@link
* #DEFAULT_MAX_DIMENSIONS}. Formats should override this method to specify the maximum number of
* dimensions they support.
*
* @param fieldName the name of the vector field whose limit is requested
* @return the maximum number of vector dimensions allowed for {@code fieldName}
*/
public int getMaxDimensions(String fieldName) {
return DEFAULT_MAX_DIMENSIONS;
}
/**
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
* support vectors.

View File

@ -185,6 +185,11 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
return new Lucene95HnswVectorsReader(state);
}
/**
* Returns the maximum number of vector dimensions this format accepts. The limit is a fixed 1024
* and does not depend on {@code fieldName}.
*/
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
@Override
public String toString() {
return "Lucene95HnswVectorsFormat(name=Lucene95HnswVectorsFormat, maxConn="

View File

@ -80,6 +80,11 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
return new FieldsReader(state);
}
/**
* Returns the dimension limit reported by the format that {@link
* #getKnnVectorsFormatForField(String)} selects for {@code fieldName}, so each field is bounded
* by its own format.
*/
@Override
public int getMaxDimensions(String fieldName) {
return getKnnVectorsFormatForField(fieldName).getMaxDimensions(fieldName);
}
/**
* Returns the numeric vector format that should be used for writing new segments of <code>field
* </code>.

View File

@ -21,7 +21,6 @@ import java.util.Map;
import java.util.Objects;
import org.apache.lucene.analysis.Analyzer; // javadocs
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
@ -378,13 +377,6 @@ public class FieldType implements IndexableFieldType {
if (numDimensions <= 0) {
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
}
if (numDimensions > FloatVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"vector numDimensions must be <= FloatVectorValues.MAX_DIMENSIONS (="
+ FloatVectorValues.MAX_DIMENSIONS
+ "); got "
+ numDimensions);
}
this.vectorDimension = numDimensions;
this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
this.vectorEncoding = Objects.requireNonNull(encoding);

View File

@ -46,10 +46,6 @@ public class KnnByteVectorField extends Field {
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
}
if (similarityFunction == null) {
throw new IllegalArgumentException("similarity function must not be null");
}

View File

@ -47,10 +47,6 @@ public class KnnFloatVectorField extends Field {
if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector");
}
if (dimension > FloatVectorValues.MAX_DIMENSIONS) {
throw new IllegalArgumentException(
"cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS);
}
if (similarityFunction == null) {
throw new IllegalArgumentException("similarity function must not be null");
}

View File

@ -28,9 +28,6 @@ import org.apache.lucene.search.DocIdSetIterator;
*/
public abstract class ByteVectorValues extends DocIdSetIterator {
/** The maximum length of a vector */
public static final int MAX_DIMENSIONS = 1024;
/** Sole constructor */
protected ByteVectorValues() {}

View File

@ -28,9 +28,6 @@ import org.apache.lucene.search.DocIdSetIterator;
*/
public abstract class FloatVectorValues extends DocIdSetIterator {
/** The maximum length of a vector */
public static final int MAX_DIMENSIONS = 1024;
/** Sole constructor */
protected FloatVectorValues() {}

View File

@ -621,6 +621,12 @@ final class IndexingChain implements Accountable {
final Sort indexSort = indexWriterConfig.getIndexSort();
validateIndexSortDVType(indexSort, pf.fieldName, s.docValuesType);
}
if (s.vectorDimension != 0) {
validateMaxVectorDimension(
pf.fieldName,
s.vectorDimension,
indexWriterConfig.getCodec().knnVectorsFormat().getMaxDimensions(pf.fieldName));
}
FieldInfo fi =
fieldInfos.add(
new FieldInfo(
@ -831,6 +837,20 @@ final class IndexingChain implements Accountable {
}
}
/**
 * Rejects a vector field whose dimension exceeds the limit supplied by the caller.
 *
 * @param fieldName name of the vector field, used only in the error message
 * @param vectorDim number of dimensions of the vector being indexed
 * @param maxVectorDim largest number of dimensions allowed for this field (inclusive)
 * @throws IllegalArgumentException if {@code vectorDim} is greater than {@code maxVectorDim}
 */
private static void validateMaxVectorDimension(
    String fieldName, int vectorDim, int maxVectorDim) {
  if (vectorDim > maxVectorDim) {
    // Note the space after "]": without it the message reads "Field [f]vector's dimensions ...".
    throw new IllegalArgumentException(
        "Field ["
            + fieldName
            + "] vector's dimensions must be <= ["
            + maxVectorDim
            + "]; got "
            + vectorDim);
  }
}
private void validateIndexSortDVType(Sort indexSort, String fieldToValidate, DocValuesType dvType)
throws IOException {
for (SortField sortField : indexSort.getSort()) {

View File

@ -29,6 +29,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
@ -43,6 +44,9 @@ import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
@ -162,6 +166,50 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
}
}
/**
 * Checks that the dimension limit is resolved per field: "field1" is served by a format that
 * reports a limit of 32, while every other field is served by a plain Lucene95HnswVectorsFormat.
 */
public void testMaxDimensionsPerFieldFormat() throws IOException {
  try (Directory dir = newDirectory()) {
    IndexWriterConfig config = newIndexWriterConfig(new MockAnalyzer(random()));
    KnnVectorsFormat cappedFormat =
        new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100));
    KnnVectorsFormat plainFormat = new Lucene95HnswVectorsFormat(16, 100);
    config.setCodec(
        new AssertingCodec() {
          @Override
          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
            return "field1".equals(field) ? cappedFormat : plainFormat;
          }
        });

    try (IndexWriter w = new IndexWriter(dir, config)) {
      // 33 dimensions is one more than the limit reported for field1.
      Document rejected = new Document();
      rejected.add(new KnnFloatVectorField("field1", new float[33]));
      Exception failure =
          expectThrows(IllegalArgumentException.class, () -> w.addDocument(rejected));
      assertTrue(failure.getMessage().contains("vector's dimensions must be <= [32]"));

      // 32 dimensions fits field1; the same 33 dimensions is accepted on field2.
      Document accepted = new Document();
      accepted.add(new KnnFloatVectorField("field1", new float[32]));
      accepted.add(new KnnFloatVectorField("field2", new float[33]));
      w.addDocument(accepted);
    }

    // Check that the vectors were written
    try (IndexReader r = DirectoryReader.open(dir)) {
      IndexSearcher s = new IndexSearcher(r);

      Query field1Query = new KnnFloatVectorQuery("field1", new float[32], 10);
      TopDocs field1Hits = s.search(field1Query, 1);
      assertEquals(1, field1Hits.scoreDocs.length);

      Query field2Query = new KnnFloatVectorQuery("field2", new float[33], 10);
      TopDocs field2Hits = s.search(field2Query, 1);
      assertEquals(1, field2Hits.scoreDocs.length);
    }
  }
}
private static class WriteRecordingKnnVectorsFormat extends KnnVectorsFormat {
private final KnnVectorsFormat delegate;
private final Set<String> fieldsWritten;
@ -216,4 +264,28 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
return delegate.fieldsReader(state);
}
}
/**
* Test-only format that forwards reading and writing to a wrapped {@link KnnVectorsFormat} but
* reports a maximum of 32 vector dimensions for every field.
*/
private static class KnnVectorsFormatMaxDims32 extends KnnVectorsFormat {
// Format that does the actual reading and writing.
private final KnnVectorsFormat delegate;
/** Wraps {@code delegate}, registering this format under the delegate's name. */
public KnnVectorsFormatMaxDims32(KnnVectorsFormat delegate) {
super(delegate.getName());
this.delegate = delegate;
}
/** Returns the writer created by the wrapped format. */
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return delegate.fieldsWriter(state);
}
/** Returns the reader created by the wrapped format. */
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return delegate.fieldsReader(state);
}
/** Always returns 32, regardless of {@code fieldName} and of the wrapped format's own limit. */
@Override
public int getMaxDimensions(String fieldName) {
return 32;
}
}
}

View File

@ -32,7 +32,6 @@ import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.PointValues;
@ -280,7 +279,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
var builder = INDEX_PACKAGE_ACCESS.newFieldInfosBuilder(softDeletesField);
for (String field : fieldNames) {
IndexableFieldType fieldType = randomFieldType(random());
IndexableFieldType fieldType = randomFieldType(random(), field);
boolean storeTermVectors = false;
boolean storePayloads = false;
boolean omitNorms = false;
@ -319,7 +318,11 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
dir.close();
}
private IndexableFieldType randomFieldType(Random r) {
/**
* Returns the maximum number of vector dimensions that the knn vectors format of the default
* codec reports for {@code fieldName}.
*/
private int getVectorsMaxDimensions(String fieldName) {
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
}
private IndexableFieldType randomFieldType(Random r, String fieldName) {
FieldType type = new FieldType();
if (r.nextBoolean()) {
@ -352,7 +355,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
}
if (r.nextBoolean()) {
int dimension = 1 + r.nextInt(FloatVectorValues.MAX_DIMENSIONS);
int dimension = 1 + r.nextInt(getVectorsMaxDimensions(fieldName));
VectorSimilarityFunction similarityFunction =
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());

View File

@ -27,7 +27,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.document.NumericDocValuesField;
@ -86,6 +85,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
/**
* Returns the maximum number of vector dimensions that the knn vectors format of the default
* codec reports for {@code fieldName}.
*/
private int getVectorsMaxDimensions(String fieldName) {
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
}
public void testFieldConstructor() {
float[] v = new float[1];
KnnFloatVectorField field = new KnnFloatVectorField("f", v);
@ -101,14 +104,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
IllegalArgumentException.class,
() -> new KnnFloatVectorField("f", new float[1], (VectorSimilarityFunction) null));
expectThrows(IllegalArgumentException.class, () -> new KnnFloatVectorField("f", new float[0]));
expectThrows(
IllegalArgumentException.class,
() -> new KnnFloatVectorField("f", new float[FloatVectorValues.MAX_DIMENSIONS + 1]));
expectThrows(
IllegalArgumentException.class,
() ->
new KnnFloatVectorField(
"f", new float[FloatVectorValues.MAX_DIMENSIONS + 1], (FieldType) null));
}
public void testFieldSetValue() {
@ -478,18 +473,42 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
expectThrows(
IllegalArgumentException.class,
() ->
doc.add(
new KnnFloatVectorField(
"f",
new float[FloatVectorValues.MAX_DIMENSIONS + 1],
VectorSimilarityFunction.DOT_PRODUCT)));
doc.add(
new KnnFloatVectorField(
"f",
new float[getVectorsMaxDimensions("f") + 1],
VectorSimilarityFunction.DOT_PRODUCT));
Exception exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
assertTrue(
exc.getMessage()
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
w.addDocument(doc2);
Document doc3 = new Document();
doc3.add(
new KnnFloatVectorField(
"f",
new float[getVectorsMaxDimensions("f") + 1],
VectorSimilarityFunction.DOT_PRODUCT));
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc3));
assertTrue(
exc.getMessage()
.contains("Inconsistency of field data structures across documents for field [f]"));
w.flush();
Document doc4 = new Document();
doc4.add(
new KnnFloatVectorField(
"f",
new float[getVectorsMaxDimensions("f") + 1],
VectorSimilarityFunction.DOT_PRODUCT));
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc4));
assertTrue(
exc.getMessage()
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
}
}