mirror of https://github.com/apache/lucene.git
Move max vector dims limit to Codec (#12436)
Move enforcement of the maximum vector dimension limit into the default Codec's KnnVectorsFormat implementation. This allows different implementations of kNN search algorithms to define their own limits on the maximum number of vector dimensions they can handle. Closes #12309
This commit is contained in:
parent
538b7d0ffe
commit
98320d7616
|
@ -129,6 +129,8 @@ API Changes
|
|||
* GITHUB#11248: IntBlockPool's SliceReader, SliceWriter, and all int slice functionality are moved out to MemoryIndex.
|
||||
(Stefan Vodita)
|
||||
|
||||
* GITHUB#12436: Move max vector dims limit to Codec (Mayya Sharipova)
|
||||
|
||||
New Features
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -32,6 +32,9 @@ import org.apache.lucene.util.NamedSPILoader;
|
|||
*/
|
||||
public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
||||
|
||||
/** The maximum number of vector dimensions */
|
||||
public static final int DEFAULT_MAX_DIMENSIONS = 1024;
|
||||
|
||||
/**
|
||||
* This static holder class prevents classloading deadlock by delaying init of doc values formats
|
||||
* until needed.
|
||||
|
@ -76,6 +79,19 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
|||
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
|
||||
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns the maximum number of vector dimensions supported by this codec for the given field
|
||||
* name
|
||||
*
|
||||
* <p>Codecs should override this method to specify the maximum number of dimensions they support.
|
||||
*
|
||||
* @param fieldName the field name
|
||||
* @return the maximum number of vector dimensions.
|
||||
*/
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return DEFAULT_MAX_DIMENSIONS;
|
||||
}
|
||||
|
||||
/**
|
||||
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
|
||||
* support vectors.
|
||||
|
|
|
@ -185,6 +185,11 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
|
|||
return new Lucene95HnswVectorsReader(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return 1024;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Lucene95HnswVectorsFormat(name=Lucene95HnswVectorsFormat, maxConn="
|
||||
|
|
|
@ -80,6 +80,11 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
return new FieldsReader(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return getKnnVectorsFormatForField(fieldName).getMaxDimensions(fieldName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the numeric vector format that should be used for writing new segments of <code>field
|
||||
* </code>.
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.util.Map;
|
|||
import java.util.Objects;
|
||||
import org.apache.lucene.analysis.Analyzer; // javadocs
|
||||
import org.apache.lucene.index.DocValuesType;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.index.IndexableFieldType;
|
||||
import org.apache.lucene.index.PointValues;
|
||||
|
@ -378,13 +377,6 @@ public class FieldType implements IndexableFieldType {
|
|||
if (numDimensions <= 0) {
|
||||
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
|
||||
}
|
||||
if (numDimensions > FloatVectorValues.MAX_DIMENSIONS) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector numDimensions must be <= FloatVectorValues.MAX_DIMENSIONS (="
|
||||
+ FloatVectorValues.MAX_DIMENSIONS
|
||||
+ "); got "
|
||||
+ numDimensions);
|
||||
}
|
||||
this.vectorDimension = numDimensions;
|
||||
this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
|
||||
this.vectorEncoding = Objects.requireNonNull(encoding);
|
||||
|
|
|
@ -46,10 +46,6 @@ public class KnnByteVectorField extends Field {
|
|||
if (dimension == 0) {
|
||||
throw new IllegalArgumentException("cannot index an empty vector");
|
||||
}
|
||||
if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
|
||||
throw new IllegalArgumentException(
|
||||
"cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
|
||||
}
|
||||
if (similarityFunction == null) {
|
||||
throw new IllegalArgumentException("similarity function must not be null");
|
||||
}
|
||||
|
|
|
@ -47,10 +47,6 @@ public class KnnFloatVectorField extends Field {
|
|||
if (dimension == 0) {
|
||||
throw new IllegalArgumentException("cannot index an empty vector");
|
||||
}
|
||||
if (dimension > FloatVectorValues.MAX_DIMENSIONS) {
|
||||
throw new IllegalArgumentException(
|
||||
"cannot index vectors with dimension greater than " + FloatVectorValues.MAX_DIMENSIONS);
|
||||
}
|
||||
if (similarityFunction == null) {
|
||||
throw new IllegalArgumentException("similarity function must not be null");
|
||||
}
|
||||
|
|
|
@ -28,9 +28,6 @@ import org.apache.lucene.search.DocIdSetIterator;
|
|||
*/
|
||||
public abstract class ByteVectorValues extends DocIdSetIterator {
|
||||
|
||||
/** The maximum length of a vector */
|
||||
public static final int MAX_DIMENSIONS = 1024;
|
||||
|
||||
/** Sole constructor */
|
||||
protected ByteVectorValues() {}
|
||||
|
||||
|
|
|
@ -28,9 +28,6 @@ import org.apache.lucene.search.DocIdSetIterator;
|
|||
*/
|
||||
public abstract class FloatVectorValues extends DocIdSetIterator {
|
||||
|
||||
/** The maximum length of a vector */
|
||||
public static final int MAX_DIMENSIONS = 1024;
|
||||
|
||||
/** Sole constructor */
|
||||
protected FloatVectorValues() {}
|
||||
|
||||
|
|
|
@ -621,6 +621,12 @@ final class IndexingChain implements Accountable {
|
|||
final Sort indexSort = indexWriterConfig.getIndexSort();
|
||||
validateIndexSortDVType(indexSort, pf.fieldName, s.docValuesType);
|
||||
}
|
||||
if (s.vectorDimension != 0) {
|
||||
validateMaxVectorDimension(
|
||||
pf.fieldName,
|
||||
s.vectorDimension,
|
||||
indexWriterConfig.getCodec().knnVectorsFormat().getMaxDimensions(pf.fieldName));
|
||||
}
|
||||
FieldInfo fi =
|
||||
fieldInfos.add(
|
||||
new FieldInfo(
|
||||
|
@ -831,6 +837,20 @@ final class IndexingChain implements Accountable {
|
|||
}
|
||||
}
|
||||
|
||||
private static void validateMaxVectorDimension(
|
||||
String fieldName, int vectorDim, int maxVectorDim) {
|
||||
if (vectorDim > maxVectorDim) {
|
||||
throw new IllegalArgumentException(
|
||||
"Field ["
|
||||
+ fieldName
|
||||
+ "]"
|
||||
+ "vector's dimensions must be <= ["
|
||||
+ maxVectorDim
|
||||
+ "]; got "
|
||||
+ vectorDim);
|
||||
}
|
||||
}
|
||||
|
||||
private void validateIndexSortDVType(Sort indexSort, String fieldToValidate, DocValuesType dvType)
|
||||
throws IOException {
|
||||
for (SortField sortField : indexSort.getSort()) {
|
||||
|
|
|
@ -29,6 +29,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
|||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
|
@ -43,6 +44,9 @@ import org.apache.lucene.index.NoMergePolicy;
|
|||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnFloatVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
|
@ -162,6 +166,50 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testMaxDimensionsPerFieldFormat() throws IOException {
|
||||
try (Directory directory = newDirectory()) {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||
KnnVectorsFormat format1 =
|
||||
new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100));
|
||||
KnnVectorsFormat format2 = new Lucene95HnswVectorsFormat(16, 100);
|
||||
iwc.setCodec(
|
||||
new AssertingCodec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
if ("field1".equals(field)) {
|
||||
return format1;
|
||||
} else {
|
||||
return format2;
|
||||
}
|
||||
}
|
||||
});
|
||||
try (IndexWriter writer = new IndexWriter(directory, iwc)) {
|
||||
Document doc1 = new Document();
|
||||
doc1.add(new KnnFloatVectorField("field1", new float[33]));
|
||||
Exception exc =
|
||||
expectThrows(IllegalArgumentException.class, () -> writer.addDocument(doc1));
|
||||
assertTrue(exc.getMessage().contains("vector's dimensions must be <= [32]"));
|
||||
|
||||
Document doc2 = new Document();
|
||||
doc2.add(new KnnFloatVectorField("field1", new float[32]));
|
||||
doc2.add(new KnnFloatVectorField("field2", new float[33]));
|
||||
writer.addDocument(doc2);
|
||||
}
|
||||
|
||||
// Check that the vectors were written
|
||||
try (IndexReader reader = DirectoryReader.open(directory)) {
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
Query query1 = new KnnFloatVectorQuery("field1", new float[32], 10);
|
||||
TopDocs topDocs1 = searcher.search(query1, 1);
|
||||
assertEquals(1, topDocs1.scoreDocs.length);
|
||||
|
||||
Query query2 = new KnnFloatVectorQuery("field2", new float[33], 10);
|
||||
TopDocs topDocs2 = searcher.search(query2, 1);
|
||||
assertEquals(1, topDocs2.scoreDocs.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static class WriteRecordingKnnVectorsFormat extends KnnVectorsFormat {
|
||||
private final KnnVectorsFormat delegate;
|
||||
private final Set<String> fieldsWritten;
|
||||
|
@ -216,4 +264,28 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
return delegate.fieldsReader(state);
|
||||
}
|
||||
}
|
||||
|
||||
private static class KnnVectorsFormatMaxDims32 extends KnnVectorsFormat {
|
||||
private final KnnVectorsFormat delegate;
|
||||
|
||||
public KnnVectorsFormatMaxDims32(KnnVectorsFormat delegate) {
|
||||
super(delegate.getName());
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
return delegate.fieldsWriter(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||
return delegate.fieldsReader(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return 32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,7 +32,6 @@ import org.apache.lucene.document.TextField;
|
|||
import org.apache.lucene.index.DocValuesType;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexOptions;
|
||||
import org.apache.lucene.index.IndexableFieldType;
|
||||
import org.apache.lucene.index.PointValues;
|
||||
|
@ -280,7 +279,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
|
|||
var builder = INDEX_PACKAGE_ACCESS.newFieldInfosBuilder(softDeletesField);
|
||||
|
||||
for (String field : fieldNames) {
|
||||
IndexableFieldType fieldType = randomFieldType(random());
|
||||
IndexableFieldType fieldType = randomFieldType(random(), field);
|
||||
boolean storeTermVectors = false;
|
||||
boolean storePayloads = false;
|
||||
boolean omitNorms = false;
|
||||
|
@ -319,7 +318,11 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
|
|||
dir.close();
|
||||
}
|
||||
|
||||
private IndexableFieldType randomFieldType(Random r) {
|
||||
private int getVectorsMaxDimensions(String fieldName) {
|
||||
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
|
||||
}
|
||||
|
||||
private IndexableFieldType randomFieldType(Random r, String fieldName) {
|
||||
FieldType type = new FieldType();
|
||||
|
||||
if (r.nextBoolean()) {
|
||||
|
@ -352,7 +355,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
|
|||
}
|
||||
|
||||
if (r.nextBoolean()) {
|
||||
int dimension = 1 + r.nextInt(FloatVectorValues.MAX_DIMENSIONS);
|
||||
int dimension = 1 + r.nextInt(getVectorsMaxDimensions(fieldName));
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
|
||||
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());
|
||||
|
|
|
@ -27,7 +27,6 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
|
|||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
|
@ -86,6 +85,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
}
|
||||
|
||||
private int getVectorsMaxDimensions(String fieldName) {
|
||||
return Codec.getDefault().knnVectorsFormat().getMaxDimensions(fieldName);
|
||||
}
|
||||
|
||||
public void testFieldConstructor() {
|
||||
float[] v = new float[1];
|
||||
KnnFloatVectorField field = new KnnFloatVectorField("f", v);
|
||||
|
@ -101,14 +104,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
IllegalArgumentException.class,
|
||||
() -> new KnnFloatVectorField("f", new float[1], (VectorSimilarityFunction) null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new KnnFloatVectorField("f", new float[0]));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new KnnFloatVectorField("f", new float[FloatVectorValues.MAX_DIMENSIONS + 1]));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
new KnnFloatVectorField(
|
||||
"f", new float[FloatVectorValues.MAX_DIMENSIONS + 1], (FieldType) null));
|
||||
}
|
||||
|
||||
public void testFieldSetValue() {
|
||||
|
@ -478,18 +473,42 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
try (Directory dir = newDirectory();
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
Document doc = new Document();
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() ->
|
||||
doc.add(
|
||||
new KnnFloatVectorField(
|
||||
"f",
|
||||
new float[FloatVectorValues.MAX_DIMENSIONS + 1],
|
||||
VectorSimilarityFunction.DOT_PRODUCT)));
|
||||
doc.add(
|
||||
new KnnFloatVectorField(
|
||||
"f",
|
||||
new float[getVectorsMaxDimensions("f") + 1],
|
||||
VectorSimilarityFunction.DOT_PRODUCT));
|
||||
Exception exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
|
||||
assertTrue(
|
||||
exc.getMessage()
|
||||
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
|
||||
|
||||
Document doc2 = new Document();
|
||||
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN));
|
||||
doc2.add(new KnnFloatVectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT));
|
||||
w.addDocument(doc2);
|
||||
|
||||
Document doc3 = new Document();
|
||||
doc3.add(
|
||||
new KnnFloatVectorField(
|
||||
"f",
|
||||
new float[getVectorsMaxDimensions("f") + 1],
|
||||
VectorSimilarityFunction.DOT_PRODUCT));
|
||||
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc3));
|
||||
assertTrue(
|
||||
exc.getMessage()
|
||||
.contains("Inconsistency of field data structures across documents for field [f]"));
|
||||
w.flush();
|
||||
|
||||
Document doc4 = new Document();
|
||||
doc4.add(
|
||||
new KnnFloatVectorField(
|
||||
"f",
|
||||
new float[getVectorsMaxDimensions("f") + 1],
|
||||
VectorSimilarityFunction.DOT_PRODUCT));
|
||||
exc = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc4));
|
||||
assertTrue(
|
||||
exc.getMessage()
|
||||
.contains("vector's dimensions must be <= [" + getVectorsMaxDimensions("f") + "]"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue