diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60FieldInfosFormat.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60FieldInfosFormat.java index e3b5a9b8902..4266fc3d2ad 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60FieldInfosFormat.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene60/Lucene60FieldInfosFormat.java @@ -30,7 +30,7 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.Directory; @@ -214,7 +214,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat { pointIndexDimensionCount, pointNumBytes, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, isSoftDeletesField); } catch (IllegalStateException e) { throw new CorruptIndexException( diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextFieldInfosFormat.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextFieldInfosFormat.java index 64e9e647b65..3a02b476ca0 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextFieldInfosFormat.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextFieldInfosFormat.java @@ -28,7 +28,7 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -158,7 +158,7 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat { SimpleTextUtil.readLine(input, scratch); assert StringHelper.startsWith(scratch.get(), VECTOR_SEARCH_STRATEGY); String scoreFunction = readString(VECTOR_SEARCH_STRATEGY.length, scratch); - VectorValues.SimilarityFunction vectorDistFunc = distanceFunction(scoreFunction); + VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction); SimpleTextUtil.readLine(input, scratch); assert StringHelper.startsWith(scratch.get(), SOFT_DELETES); @@ -201,8 +201,8 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat { return DocValuesType.valueOf(dvType); } - public VectorValues.SimilarityFunction distanceFunction(String scoreFunction) { - return VectorValues.SimilarityFunction.valueOf(scoreFunction); + public VectorSimilarityFunction distanceFunction(String scoreFunction) { + return VectorSimilarityFunction.valueOf(scoreFunction); } private String readString(int offset, BytesRefBuilder scratch) { diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorReader.java index 5f779a5c3c1..7c635a72350 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorReader.java @@ -81,9 +81,6 @@ public class SimpleTextVectorReader extends VectorReader { int fieldNumber = readInt(in, FIELD_NUMBER); while (fieldNumber != -1) { String fieldName = readString(in, FIELD_NAME); - String scoreFunctionName = readString(in, SCORE_FUNCTION); - VectorValues.SimilarityFunction similarityFunction = - VectorValues.SimilarityFunction.valueOf(scoreFunctionName); long vectorDataOffset = readLong(in, VECTOR_DATA_OFFSET); long vectorDataLength = readLong(in, VECTOR_DATA_LENGTH); int dimension = readInt(in, VECTOR_DIMENSION); @@ -94,9 +91,7 @@ public class SimpleTextVectorReader extends VectorReader { } assert fieldEntries.containsKey(fieldName) == false; fieldEntries.put( - fieldName, - new FieldEntry( - dimension, similarityFunction, vectorDataOffset, vectorDataLength, docIds)); + fieldName, new FieldEntry(dimension, vectorDataOffset, vectorDataLength, docIds)); fieldNumber = readInt(in, FIELD_NUMBER); } SimpleTextUtil.checkFooter(in); @@ -205,20 +200,13 @@ public class SimpleTextVectorReader extends VectorReader { private static class FieldEntry { final int dimension; - final VectorValues.SimilarityFunction similarityFunction; final long vectorDataOffset; final long vectorDataLength; final int[] ordToDoc; - FieldEntry( - int dimension, - VectorValues.SimilarityFunction similarityFunction, - long vectorDataOffset, - long vectorDataLength, - int[] ordToDoc) { + FieldEntry(int dimension, long vectorDataOffset, long vectorDataLength, int[] ordToDoc) { this.dimension = dimension; - this.similarityFunction = similarityFunction; this.vectorDataOffset = vectorDataOffset; this.vectorDataLength = vectorDataLength; this.ordToDoc = ordToDoc; @@ -260,11 +248,6 @@ public class SimpleTextVectorReader extends VectorReader { return entry.size(); } - @Override - public SimilarityFunction similarityFunction() { - return entry.similarityFunction; - } - @Override public float[] vectorValue() { return values[curOrd]; diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorWriter.java index 308bd53dfbe..1bba30cbd68 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextVectorWriter.java @@ -38,7 +38,6 @@ public class SimpleTextVectorWriter extends VectorWriter { static final BytesRef FIELD_NUMBER = new BytesRef("field-number "); static final BytesRef FIELD_NAME = new BytesRef("field-name "); - static final BytesRef SCORE_FUNCTION = new BytesRef("score-function "); static final BytesRef VECTOR_DATA_OFFSET = new BytesRef("vector-data-offset "); static final BytesRef VECTOR_DATA_LENGTH = new BytesRef("vector-data-length "); static final BytesRef VECTOR_DIMENSION = new BytesRef("vector-dimension "); @@ -96,7 +95,6 @@ public class SimpleTextVectorWriter extends VectorWriter { throws IOException { writeField(meta, FIELD_NUMBER, field.number); writeField(meta, FIELD_NAME, field.name); - writeField(meta, SCORE_FUNCTION, field.getVectorSimilarityFunction().name()); writeField(meta, VECTOR_DATA_OFFSET, vectorDataOffset); writeField(meta, VECTOR_DATA_LENGTH, vectorDataLength); writeField(meta, VECTOR_DIMENSION, field.getVectorDimension()); diff --git a/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/TestBlockWriter.java b/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/TestBlockWriter.java index b8859183e19..b9bb0fab67d 100644 --- a/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/TestBlockWriter.java +++ b/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/TestBlockWriter.java @@ -23,7 +23,7 @@ import org.apache.lucene.codecs.lucene90.MockTermStateFactory; import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ByteBuffersDataOutput; import org.apache.lucene.store.ByteBuffersIndexOutput; import org.apache.lucene.util.BytesRef; @@ -116,7 +116,7 @@ public class TestBlockWriter extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, true); } } diff --git a/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/sharedterms/TestSTBlockReader.java b/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/sharedterms/TestSTBlockReader.java index 78f9793aff9..b30a32d4134 100644 --- a/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/sharedterms/TestSTBlockReader.java +++ b/lucene/codecs/src/test/org/apache/lucene/codecs/uniformsplit/sharedterms/TestSTBlockReader.java @@ -41,7 +41,7 @@ import org.apache.lucene.index.ImpactsEnum; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.SegmentReadState; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.Directory; @@ -203,7 +203,7 @@ public class TestSTBlockReader extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java index 90d9acce64f..433009ee540 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/VectorWriter.java @@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; import org.apache.lucene.util.BytesRef; @@ -68,14 +69,14 @@ public abstract class VectorWriter implements Closeable { } List subs = new ArrayList<>(); int dimension = -1; - VectorValues.SimilarityFunction similarityFunction = null; + VectorSimilarityFunction similarityFunction = null; int nonEmptySegmentIndex = 0; for (int i = 0; i < mergeState.vectorReaders.length; i++) { VectorReader vectorReader = mergeState.vectorReaders[i]; if (vectorReader != null) { if (mergeFieldInfo != null && mergeFieldInfo.hasVectorValues()) { int segmentDimension = mergeFieldInfo.getVectorDimension(); - VectorValues.SimilarityFunction segmentSimilarityFunction = + VectorSimilarityFunction segmentSimilarityFunction = mergeFieldInfo.getVectorSimilarityFunction(); if (dimension == -1) { dimension = segmentDimension; @@ -238,11 +239,6 @@ public abstract class VectorWriter implements Closeable { return subs.get(0).values.dimension(); } - @Override - public SimilarityFunction similarityFunction() { - return subs.get(0).values.similarityFunction(); - } - class MergerRandomAccess implements RandomAccessVectorValues { private final List raSubs; @@ -269,11 +265,6 @@ public abstract class VectorWriter implements Closeable { return VectorValuesMerger.this.dimension(); } - @Override - public SimilarityFunction similarityFunction() { - return VectorValuesMerger.this.similarityFunction(); - } - @Override public float[] vectorValue(int target) throws IOException { int unmappedOrd = ordMap[target]; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java index d0c99267500..31ee09b7d4d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90FieldInfosFormat.java @@ -29,7 +29,7 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.SegmentInfo; -import org.apache.lucene.index.VectorValues.SimilarityFunction; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.Directory; @@ -102,8 +102,8 @@ import org.apache.lucene.store.IndexOutput; *
  • VectorDistFunction: a byte containing distance function used for similarity calculation. *
      *
    • 0: no distance function is defined for this field. - *
    • 1: EUCLIDEAN_HNSW distance. ({@link SimilarityFunction#EUCLIDEAN}) - *
    • 2: DOT_PRODUCT_HNSW score. ({@link SimilarityFunction#DOT_PRODUCT}) + *
    • 1: EUCLIDEAN_HNSW distance. ({@link VectorSimilarityFunction#EUCLIDEAN}) + *
    • 2: DOT_PRODUCT_HNSW score. ({@link VectorSimilarityFunction#DOT_PRODUCT}) *
    * * @@ -172,7 +172,7 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat { pointNumBytes = 0; } final int vectorDimension = input.readVInt(); - final SimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte()); + final VectorSimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte()); try { infos[i] = @@ -253,11 +253,11 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat { } } - private static SimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException { - if (b < 0 || b >= SimilarityFunction.values().length) { + private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException { + if (b < 0 || b >= VectorSimilarityFunction.values().length) { throw new CorruptIndexException("invalid distance function: " + b, input); } - return SimilarityFunction.values()[b]; + return VectorSimilarityFunction.values()[b]; } static { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorReader.java index 08e86943927..e5bdf4d7b56 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorReader.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -187,19 +188,18 @@ public final class Lucene90HnswVectorReader extends VectorReader { } } - private VectorValues.SimilarityFunction readSimilarityFunction(DataInput input) - throws IOException { + private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { int similarityFunctionId = input.readInt(); if (similarityFunctionId < 0 - || similarityFunctionId >= VectorValues.SimilarityFunction.values().length) { + || similarityFunctionId >= VectorSimilarityFunction.values().length) { throw new CorruptIndexException( "Invalid similarity function id: " + similarityFunctionId, input); } - return VectorValues.SimilarityFunction.values()[similarityFunctionId]; + return VectorSimilarityFunction.values()[similarityFunctionId]; } private FieldEntry readField(DataInput input) throws IOException { - VectorValues.SimilarityFunction similarityFunction = readSimilarityFunction(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); switch (similarityFunction) { case NONE: return new FieldEntry(input, similarityFunction); @@ -252,7 +252,14 @@ public final class Lucene90HnswVectorReader extends VectorReader { // use a seed that is fixed for the index so we get reproducible results for the same query final Random random = new Random(checksumSeed); NeighborQueue results = - HnswGraph.search(target, k, k, vectorValues, getGraphValues(fieldEntry), random); + HnswGraph.search( + target, + k, + k, + vectorValues, + fieldEntry.similarityFunction, + getGraphValues(fieldEntry), + random); int i = 0; ScoreDoc[] scoreDocs = new ScoreDoc[Math.min(results.size(), k)]; boolean reversed = fieldEntry.similarityFunction.reversed; @@ -292,7 +299,7 @@ public final class Lucene90HnswVectorReader extends VectorReader { } private KnnGraphValues getGraphValues(FieldEntry entry) throws IOException { - if (entry.similarityFunction != VectorValues.SimilarityFunction.NONE) { + if (entry.similarityFunction != VectorSimilarityFunction.NONE) { HnswGraphFieldEntry graphEntry = (HnswGraphFieldEntry) entry; IndexInput bytesSlice = vectorIndex.slice("graph-data", entry.indexDataOffset, entry.indexDataLength); @@ -310,7 +317,7 @@ public final class Lucene90HnswVectorReader extends VectorReader { private static class FieldEntry { final int dimension; - final VectorValues.SimilarityFunction similarityFunction; + final VectorSimilarityFunction similarityFunction; final long vectorDataOffset; final long vectorDataLength; @@ -318,8 +325,7 @@ public final class Lucene90HnswVectorReader extends VectorReader { final long indexDataLength; final int[] ordToDoc; - FieldEntry(DataInput input, VectorValues.SimilarityFunction similarityFunction) - throws IOException { + FieldEntry(DataInput input, VectorSimilarityFunction similarityFunction) throws IOException { this.similarityFunction = similarityFunction; vectorDataOffset = input.readVLong(); vectorDataLength = input.readVLong(); @@ -343,7 +349,7 @@ public final class Lucene90HnswVectorReader extends VectorReader { final long[] ordOffsets; - HnswGraphFieldEntry(DataInput input, VectorValues.SimilarityFunction similarityFunction) + HnswGraphFieldEntry(DataInput input, VectorSimilarityFunction similarityFunction) throws IOException { super(input, similarityFunction); ordOffsets = new long[size()]; @@ -389,11 +395,6 @@ public final class Lucene90HnswVectorReader extends VectorReader { return fieldEntry.size(); } - @Override - public SimilarityFunction similarityFunction() { - return fieldEntry.similarityFunction; - } - @Override public float[] vectorValue() throws IOException { dataIn.seek((long) ord * byteSize); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorWriter.java index 7cb8f7c60ee..c35ad2ce2a5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorWriter.java @@ -27,6 +27,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.RandomAccessVectorValuesProducer; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.BytesRef; @@ -126,11 +127,12 @@ public final class Lucene90HnswVectorWriter extends VectorWriter { long[] offsets = new long[count]; long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; long vectorIndexOffset = vectorIndex.getFilePointer(); - if (vectors.similarityFunction() != VectorValues.SimilarityFunction.NONE) { + if (fieldInfo.getVectorSimilarityFunction() != VectorSimilarityFunction.NONE) { if (vectors instanceof RandomAccessVectorValuesProducer) { writeGraph( vectorIndex, (RandomAccessVectorValuesProducer) vectors, + fieldInfo.getVectorSimilarityFunction(), vectorIndexOffset, offsets, count, @@ -150,7 +152,7 @@ public final class Lucene90HnswVectorWriter extends VectorWriter { vectorIndexLength, count, docIds); - if (vectors.similarityFunction() != VectorValues.SimilarityFunction.NONE) { + if (fieldInfo.getVectorSimilarityFunction() != VectorSimilarityFunction.NONE) { writeGraphOffsets(meta, offsets); } } @@ -196,6 +198,7 @@ public final class Lucene90HnswVectorWriter extends VectorWriter { private void writeGraph( IndexOutput graphData, RandomAccessVectorValuesProducer vectorValues, + VectorSimilarityFunction similarityFunction, long graphDataOffset, long[] offsets, int count, @@ -203,7 +206,8 @@ public final class Lucene90HnswVectorWriter extends VectorWriter { int beamWidth) throws IOException { HnswGraphBuilder hnswGraphBuilder = - new HnswGraphBuilder(vectorValues, maxConn, beamWidth, HnswGraphBuilder.randSeed); + new HnswGraphBuilder( + vectorValues, similarityFunction, maxConn, beamWidth, HnswGraphBuilder.randSeed); hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); HnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess()); diff --git a/lucene/core/src/java/org/apache/lucene/document/FieldType.java b/lucene/core/src/java/org/apache/lucene/document/FieldType.java index 19dd2aa42e8..f777ea57867 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FieldType.java +++ b/lucene/core/src/java/org/apache/lucene/document/FieldType.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexableFieldType; import org.apache.lucene.index.PointValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; /** Describes the properties of a field. */ @@ -42,8 +43,7 @@ public class FieldType implements IndexableFieldType { private int indexDimensionCount; private int dimensionNumBytes; private int vectorDimension; - private VectorValues.SimilarityFunction vectorSimilarityFunction = - VectorValues.SimilarityFunction.NONE; + private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.NONE; private Map attributes; /** Create a new mutable FieldType with all of the properties from ref */ @@ -371,7 +371,7 @@ public class FieldType implements IndexableFieldType { /** Enable vector indexing, with the specified number of dimensions and distance function. */ public void setVectorDimensionsAndSimilarityFunction( - int numDimensions, VectorValues.SimilarityFunction distFunc) { + int numDimensions, VectorSimilarityFunction distFunc) { checkIfFrozen(); if (numDimensions <= 0) { throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions); @@ -393,7 +393,7 @@ public class FieldType implements IndexableFieldType { } @Override - public VectorValues.SimilarityFunction vectorSimilarityFunction() { + public VectorSimilarityFunction vectorSimilarityFunction() { return vectorSimilarityFunction; } diff --git a/lucene/core/src/java/org/apache/lucene/document/VectorField.java b/lucene/core/src/java/org/apache/lucene/document/VectorField.java index 7e7d3f405c8..e4e6d13ba8f 100644 --- a/lucene/core/src/java/org/apache/lucene/document/VectorField.java +++ b/lucene/core/src/java/org/apache/lucene/document/VectorField.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; /** @@ -33,8 +34,7 @@ import org.apache.lucene.index.VectorValues; */ public class VectorField extends Field { - private static FieldType createType( - float[] v, VectorValues.SimilarityFunction similarityFunction) { + private static FieldType createType(float[] v, VectorSimilarityFunction similarityFunction) { if (v == null) { throw new IllegalArgumentException("vector value must not be null"); } @@ -63,7 +63,7 @@ public class VectorField extends Field { * @throws IllegalArgumentException if any parameter is null, or has dimension > 1024. */ public static FieldType createFieldType( - int dimension, VectorValues.SimilarityFunction similarityFunction) { + int dimension, VectorSimilarityFunction similarityFunction) { FieldType type = new FieldType(); type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction); type.freeze(); @@ -82,8 +82,7 @@ public class VectorField extends Field { * @throws IllegalArgumentException if any parameter is null, or the vector is empty or has * dimension > 1024. */ - public VectorField( - String name, float[] vector, VectorValues.SimilarityFunction similarityFunction) { + public VectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) { super(name, createType(vector, similarityFunction)); fieldsData = vector; } @@ -99,7 +98,7 @@ public class VectorField extends Field { * dimension > 1024. */ public VectorField(String name, float[] vector) { - this(name, vector, VectorValues.SimilarityFunction.EUCLIDEAN); + this(name, vector, VectorSimilarityFunction.EUCLIDEAN); } /** diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java b/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java index 656077fcb95..bc1b5073425 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java +++ b/lucene/core/src/java/org/apache/lucene/index/FieldInfo.java @@ -56,7 +56,7 @@ public final class FieldInfo { // if it is a positive value, it means this field indexes vectors private final int vectorDimension; - private final VectorValues.SimilarityFunction vectorSimilarityFunction; + private final VectorSimilarityFunction vectorSimilarityFunction; // whether this field is used as the soft-deletes field private final boolean softDeletesField; @@ -80,7 +80,7 @@ public final class FieldInfo { int pointIndexDimensionCount, int pointNumBytes, int vectorDimension, - VectorValues.SimilarityFunction vectorSimilarityFunction, + VectorSimilarityFunction vectorSimilarityFunction, boolean softDeletesField) { this.name = Objects.requireNonNull(name); this.number = number; @@ -202,7 +202,7 @@ public final class FieldInfo { throw new IllegalArgumentException( "vectorDimension must be >=0; got " + vectorDimension + " (field: '" + name + "')"); } - if (vectorDimension == 0 && vectorSimilarityFunction != VectorValues.SimilarityFunction.NONE) { + if (vectorDimension == 0 && vectorSimilarityFunction != VectorSimilarityFunction.NONE) { throw new IllegalArgumentException( "vector similarity function must be NONE when dimension = 0; got " + vectorSimilarityFunction @@ -355,9 +355,9 @@ public final class FieldInfo { static void verifySameVectorOptions( String fieldName, int vd1, - VectorValues.SimilarityFunction vsf1, + VectorSimilarityFunction vsf1, int vd2, - VectorValues.SimilarityFunction vsf2) { + VectorSimilarityFunction vsf2) { if (vd1 != vd2 || vsf1 != vsf2) { throw new IllegalArgumentException( "cannot change field \"" @@ -478,8 +478,8 @@ public final class FieldInfo { return vectorDimension; } - /** Returns {@link VectorValues.SimilarityFunction} for the field */ - public VectorValues.SimilarityFunction getVectorSimilarityFunction() { + /** Returns {@link VectorSimilarityFunction} for the field */ + public VectorSimilarityFunction getVectorSimilarityFunction() { return vectorSimilarityFunction; } diff --git a/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java b/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java index 0c8d4e5aad1..79f85baf8b5 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java +++ b/lucene/core/src/java/org/apache/lucene/index/FieldInfos.java @@ -299,9 +299,9 @@ public class FieldInfos implements Iterable { static final class FieldVectorProperties { final int numDimensions; - final VectorValues.SimilarityFunction similarityFunction; + final VectorSimilarityFunction similarityFunction; - FieldVectorProperties(int numDimensions, VectorValues.SimilarityFunction similarityFunction) { + FieldVectorProperties(int numDimensions, VectorSimilarityFunction similarityFunction) { this.numDimensions = numDimensions; this.similarityFunction = similarityFunction; } @@ -486,7 +486,7 @@ public class FieldInfos implements Iterable { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, (softDeletesFieldName != null && softDeletesFieldName.equals(fieldName))); addOrGet(fi); } @@ -567,7 +567,7 @@ public class FieldInfos implements Iterable { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, isSoftDeletesField); } diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java b/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java index c9e373d2d0c..ec95aefd3ba 100644 --- a/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java +++ b/lucene/core/src/java/org/apache/lucene/index/IndexableFieldType.java @@ -101,8 +101,8 @@ public interface IndexableFieldType { /** The number of dimensions of the field's vector value */ int vectorDimension(); - /** The {@link VectorValues.SimilarityFunction} of the field's vector value */ - VectorValues.SimilarityFunction vectorSimilarityFunction(); + /** The {@link VectorSimilarityFunction} of the field's vector value */ + VectorSimilarityFunction vectorSimilarityFunction(); /** * Attributes for the field type. diff --git a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java index c6ac6957f7c..2cca77aa71f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java +++ b/lucene/core/src/java/org/apache/lucene/index/IndexingChain.java @@ -1327,8 +1327,7 @@ final class IndexingChain implements Accountable { private int pointIndexDimensionCount = 0; private int pointNumBytes = 0; private int vectorDimension = 0; - private VectorValues.SimilarityFunction vectorSimilarityFunction = - VectorValues.SimilarityFunction.NONE; + private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.NONE; private static String errMsg = "Inconsistency of field data structures across documents for field "; @@ -1383,8 +1382,8 @@ final class IndexingChain implements Accountable { } } - void setVectors(VectorValues.SimilarityFunction similarityFunction, int dimension) { - if (vectorSimilarityFunction == VectorValues.SimilarityFunction.NONE) { + void setVectors(VectorSimilarityFunction similarityFunction, int dimension) { + if (vectorSimilarityFunction == VectorSimilarityFunction.NONE) { this.vectorDimension = dimension; this.vectorSimilarityFunction = similarityFunction; } else { @@ -1403,7 +1402,7 @@ final class IndexingChain implements Accountable { pointIndexDimensionCount = 0; pointNumBytes = 0; vectorDimension = 0; - vectorSimilarityFunction = VectorValues.SimilarityFunction.NONE; + vectorSimilarityFunction = VectorSimilarityFunction.NONE; } void assertSameSchema(FieldInfo fi) { diff --git a/lucene/core/src/java/org/apache/lucene/index/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/RandomAccessVectorValues.java index 562272f17e5..3f4ee0cb209 100644 --- a/lucene/core/src/java/org/apache/lucene/index/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/RandomAccessVectorValues.java @@ -33,9 +33,6 @@ public interface RandomAccessVectorValues { /** Return the dimension of the returned vector values */ int dimension(); - /** Return the similarity function used to compare these vectors */ - VectorValues.SimilarityFunction similarityFunction(); - /** * Return the vector value indexed at the given ordinal. The provided floating point array may be * shared and overwritten by subsequent calls to this method and {@link #binaryValue(int)}. diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java new file mode 100644 index 00000000000..5e67a374e1f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.index; + +import static org.apache.lucene.util.VectorUtil.dotProduct; +import static org.apache.lucene.util.VectorUtil.squareDistance; + +import org.apache.lucene.codecs.VectorReader; + +/** + * Vector similarity function; used in search to return top K most similar vectors to a target + * vector. This is a label describing the method used during indexing and searching of the vectors + * in order to determine the nearest neighbors. + */ +public enum VectorSimilarityFunction { + + /** + * No similarity function is provided. Note: {@link VectorReader#search(String, float[], int)} is + * not supported for fields specifying this. + */ + NONE, + + /** HNSW graph built using Euclidean distance */ + EUCLIDEAN(true), + + /** HNSW graph buit using dot product */ + DOT_PRODUCT; + + /** + * If true, the scores associated with vector comparisons are in reverse order; that is, lower + * scores represent more similar vectors. Otherwise, if false, higher scores represent more + * similar vectors. + */ + public final boolean reversed; + + VectorSimilarityFunction(boolean reversed) { + this.reversed = reversed; + } + + VectorSimilarityFunction() { + reversed = false; + } + + /** + * Calculates a similarity score between the two vectors with a specified function. + * + * @param v1 a vector + * @param v2 another vector, of the same dimension + * @return the value of the similarity function applied to the two vectors + */ + public float compare(float[] v1, float[] v2) { + switch (this) { + case EUCLIDEAN: + return squareDistance(v1, v2); + case DOT_PRODUCT: + return dotProduct(v1, v2); + case NONE: + default: + throw new IllegalStateException("Incomparable similarity function: " + this); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java index ee5f763a720..77f7d273de5 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorValues.java @@ -16,11 +16,7 @@ */ package org.apache.lucene.index; -import static org.apache.lucene.util.VectorUtil.dotProduct; -import static org.apache.lucene.util.VectorUtil.squareDistance; - import java.io.IOException; -import org.apache.lucene.codecs.VectorReader; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BytesRef; @@ -50,9 +46,6 @@ public abstract class VectorValues extends DocIdSetIterator { */ public abstract int size(); - /** Return the similarity function used to compare these vectors */ - public abstract SimilarityFunction similarityFunction(); - /** * Return the vector value for the current document ID. It is illegal to call this method when the * iterator is not positioned: before advancing, or after failing to advance. The returned array @@ -74,60 +67,6 @@ public abstract class VectorValues extends DocIdSetIterator { throw new UnsupportedOperationException(); } - /** - * Vector similarity function; used in search to return top K most similar vectors to a target - * vector. This is a label describing the method used during indexing and searching of the vectors - * in order to determine the nearest neighbors. - */ - public enum SimilarityFunction { - - /** - * No similarity function is provided. Note: {@link VectorReader#search(String, float[], int)} - * is not supported for fields specifying this. - */ - NONE, - - /** HNSW graph built using Euclidean distance */ - EUCLIDEAN(true), - - /** HNSW graph buit using dot product */ - DOT_PRODUCT; - - /** - * If true, the scores associated with vector comparisons are in reverse order; that is, lower - * scores represent more similar vectors. Otherwise, if false, higher scores represent more - * similar vectors. - */ - public final boolean reversed; - - SimilarityFunction(boolean reversed) { - this.reversed = reversed; - } - - SimilarityFunction() { - reversed = false; - } - - /** - * Calculates a similarity score between the two vectors with a specified function. - * - * @param v1 a vector - * @param v2 another vector, of the same dimension - * @return the value of the similarity function applied to the two vectors - */ - public float compare(float[] v1, float[] v2) { - switch (this) { - case EUCLIDEAN: - return squareDistance(v1, v2); - case DOT_PRODUCT: - return dotProduct(v1, v2); - case NONE: - default: - throw new IllegalStateException("Incomparable similarity function: " + this); - } - } - } - /** * Represents the lack of vector values. It is returned by providers that do not support * VectorValues. @@ -145,11 +84,6 @@ public abstract class VectorValues extends DocIdSetIterator { return 0; } - @Override - public SimilarityFunction similarityFunction() { - return SimilarityFunction.NONE; - } - @Override public float[] vectorValue() { throw new IllegalStateException( diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java index 13e9b3da1d2..f8b08fa4142 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorValuesWriter.java @@ -108,11 +108,7 @@ class VectorValuesWriter { */ public void flush(Sorter.DocMap sortMap, VectorWriter vectorWriter) throws IOException { VectorValues vectorValues = - new BufferedVectorValues( - docsWithField, - vectors, - fieldInfo.getVectorDimension(), - fieldInfo.getVectorSimilarityFunction()); + new BufferedVectorValues(docsWithField, vectors, fieldInfo.getVectorDimension()); if (sortMap != null) { vectorWriter.writeField(fieldInfo, new SortingVectorValues(vectorValues, sortMap)); } else { @@ -189,11 +185,6 @@ class VectorValuesWriter { return delegate.size(); } - @Override - public SimilarityFunction similarityFunction() { - return delegate.similarityFunction(); - } - @Override public int advance(int target) throws IOException { throw new UnsupportedOperationException(); @@ -223,11 +214,6 @@ class VectorValuesWriter { return delegateRA.dimension(); } - @Override - public SimilarityFunction similarityFunction() { - return delegateRA.similarityFunction(); - } - @Override public float[] vectorValue(int targetOrd) throws IOException { return delegateRA.vectorValue(ordMap[targetOrd]); @@ -248,7 +234,6 @@ class VectorValuesWriter { // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; - final SimilarityFunction similarityFunction; final int dimension; final ByteBuffer buffer; @@ -259,15 +244,10 @@ class VectorValuesWriter { DocIdSetIterator docsWithFieldIter; int ord = -1; - BufferedVectorValues( - DocsWithFieldSet docsWithField, - List vectors, - int dimension, - SimilarityFunction similarityFunction) { + BufferedVectorValues(DocsWithFieldSet docsWithField, List vectors, int dimension) { this.docsWithField = docsWithField; this.vectors = vectors; this.dimension = dimension; - this.similarityFunction = similarityFunction; buffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); binaryValue = new BytesRef(buffer.array()); raBuffer = ByteBuffer.allocate(dimension * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); @@ -277,7 +257,7 @@ class VectorValuesWriter { @Override public RandomAccessVectorValues randomAccess() { - return new BufferedVectorValues(docsWithField, vectors, dimension, similarityFunction); + return new BufferedVectorValues(docsWithField, vectors, dimension); } @Override @@ -290,11 +270,6 @@ class VectorValuesWriter { return vectors.size(); } - @Override - public SimilarityFunction similarityFunction() { - return similarityFunction; - } - @Override public BytesRef binaryValue() { buffer.asFloatBuffer().put(vectorValue()); diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java index 6fb7caad69c..49f2c95d1f8 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraph.java @@ -25,7 +25,7 @@ import java.util.List; import java.util.Random; import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.RandomAccessVectorValues; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.SparseFixedBitSet; /** @@ -91,10 +91,10 @@ public final class HnswGraph extends KnnGraphValues { int topK, int numSeed, RandomAccessVectorValues vectors, + VectorSimilarityFunction similarityFunction, KnnGraphValues graphValues, Random random) throws IOException { - VectorValues.SimilarityFunction similarityFunction = vectors.similarityFunction(); int size = graphValues.size(); // MIN heap, holding the top results diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java index 16d4fc6f8fe..b99cb02f675 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphBuilder.java @@ -22,7 +22,7 @@ import java.util.Locale; import java.util.Random; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.InfoStream; /** @@ -42,7 +42,7 @@ public final class HnswGraphBuilder { private final int beamWidth; private final NeighborArray scratch; - private final VectorValues.SimilarityFunction similarityFunction; + private final VectorSimilarityFunction similarityFunction; private final RandomAccessVectorValues vectorValues; private final Random random; private final BoundsChecker bound; @@ -67,11 +67,15 @@ public final class HnswGraphBuilder { * to ensure repeatable construction. */ public HnswGraphBuilder( - RandomAccessVectorValuesProducer vectors, int maxConn, int beamWidth, long seed) { + RandomAccessVectorValuesProducer vectors, + VectorSimilarityFunction similarityFunction, + int maxConn, + int beamWidth, + long seed) { vectorValues = vectors.randomAccess(); buildVectors = vectors.randomAccess(); - similarityFunction = vectorValues.similarityFunction(); - if (similarityFunction == VectorValues.SimilarityFunction.NONE) { + this.similarityFunction = similarityFunction; + if (similarityFunction == VectorSimilarityFunction.NONE) { throw new IllegalStateException("No distance function"); } if (maxConn <= 0) { @@ -133,7 +137,8 @@ public final class HnswGraphBuilder { /** Inserts a doc with vector value to the graph */ void addGraphNode(float[] value) throws IOException { NeighborQueue candidates = - HnswGraph.search(value, beamWidth, beamWidth, vectorValues, hnsw, random); + HnswGraph.search( + value, beamWidth, beamWidth, vectorValues, similarityFunction, hnsw, random); int node = hnsw.addNode(); diff --git a/lucene/core/src/test/org/apache/lucene/document/TestPerFieldConsistency.java b/lucene/core/src/test/org/apache/lucene/document/TestPerFieldConsistency.java index 98878da7030..140e42ae797 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestPerFieldConsistency.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestPerFieldConsistency.java @@ -32,7 +32,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.NoMergePolicy; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.LuceneTestCase; @@ -88,10 +88,10 @@ public class TestPerFieldConsistency extends LuceneTestCase { } private static Field randomVectorField(Random random, String fieldName) { - VectorValues.SimilarityFunction similarityFunction = - RandomPicks.randomFrom(random, VectorValues.SimilarityFunction.values()); - while (similarityFunction == VectorValues.SimilarityFunction.NONE) { - similarityFunction = RandomPicks.randomFrom(random, VectorValues.SimilarityFunction.values()); + VectorSimilarityFunction similarityFunction = + RandomPicks.randomFrom(random, VectorSimilarityFunction.values()); + while (similarityFunction == VectorSimilarityFunction.NONE) { + similarityFunction = RandomPicks.randomFrom(random, VectorSimilarityFunction.values()); } float[] values = new float[randomIntBetween(1, 10)]; for (int i = 0; i < values.length; i++) { diff --git a/lucene/core/src/test/org/apache/lucene/index/TestCodecs.java b/lucene/core/src/test/org/apache/lucene/index/TestCodecs.java index b8f7e069975..2b7c06f27eb 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestCodecs.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestCodecs.java @@ -112,7 +112,7 @@ public class TestCodecs extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false)); } this.terms = terms; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestDocumentWriter.java b/lucene/core/src/test/org/apache/lucene/index/TestDocumentWriter.java index 991e6bbd1d8..045f3cd0ae9 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestDocumentWriter.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestDocumentWriter.java @@ -381,7 +381,6 @@ public class TestDocumentWriter extends LuceneTestCase { public void testRAMUsageVector() throws IOException { doTestRAMUsage( field -> - new VectorField( - field, new float[] {1, 2, 3, 4}, VectorValues.SimilarityFunction.EUCLIDEAN)); + new VectorField(field, new float[] {1, 2, 3, 4}, VectorSimilarityFunction.EUCLIDEAN)); } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java index 3baa5649b01..5c3a122268b 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldInfos.java @@ -260,7 +260,7 @@ public class TestFieldInfos extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false)); } int idx = @@ -279,7 +279,7 @@ public class TestFieldInfos extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false)); assertEquals("Field numbers 0 through 9 were allocated", 10, idx); @@ -300,7 +300,7 @@ public class TestFieldInfos extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false)); assertEquals("Field numbers should reset after clear()", 0, idx); } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestFieldsReader.java b/lucene/core/src/test/org/apache/lucene/index/TestFieldsReader.java index 3563929457b..04ee1a1d384 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestFieldsReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestFieldsReader.java @@ -63,7 +63,7 @@ public class TestFieldsReader extends LuceneTestCase { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, field.name().equals(softDeletesFieldName))); } dir = newDirectory(); diff --git a/lucene/core/src/test/org/apache/lucene/index/TestIndexableField.java b/lucene/core/src/test/org/apache/lucene/index/TestIndexableField.java index 543e8088663..b68b0b38405 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestIndexableField.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestIndexableField.java @@ -113,8 +113,8 @@ public class TestIndexableField extends LuceneTestCase { } @Override - public VectorValues.SimilarityFunction vectorSimilarityFunction() { - return VectorValues.SimilarityFunction.NONE; + public VectorSimilarityFunction vectorSimilarityFunction() { + return VectorSimilarityFunction.NONE; } @Override diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 1875bd83436..98b3652fca0 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -38,7 +38,6 @@ import org.apache.lucene.document.FieldType; import org.apache.lucene.document.SortedDocValuesField; import org.apache.lucene.document.StringField; import org.apache.lucene.document.VectorField; -import org.apache.lucene.index.VectorValues.SimilarityFunction; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; @@ -58,7 +57,7 @@ public class TestKnnGraph extends LuceneTestCase { private static int maxConn = Lucene90HnswVectorFormat.DEFAULT_MAX_CONN; private Codec codec; - private SimilarityFunction similarityFunction; + private VectorSimilarityFunction similarityFunction; @Before public void setup() { @@ -76,8 +75,8 @@ public class TestKnnGraph extends LuceneTestCase { } }; - int similarity = random().nextInt(SimilarityFunction.values().length - 1) + 1; - similarityFunction = SimilarityFunction.values()[similarity]; + int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1; + similarityFunction = VectorSimilarityFunction.values()[similarity]; } @After @@ -227,7 +226,7 @@ public class TestKnnGraph extends LuceneTestCase { /** Verify that searching does something reasonable */ public void testSearch() throws Exception { // We can't use dot product here since the vectors are laid out on a grid, not a sphere. - similarityFunction = SimilarityFunction.EUCLIDEAN; + similarityFunction = VectorSimilarityFunction.EUCLIDEAN; IndexWriterConfig config = newIndexWriterConfig(); config.setCodec(codec); // test is not compatible with simpletext try (Directory dir = newDirectory(); @@ -454,7 +453,8 @@ public class TestKnnGraph extends LuceneTestCase { add(iw, id, vector, similarityFunction); } - private void add(IndexWriter iw, int id, float[] vector, SimilarityFunction similarityFunction) + private void add( + IndexWriter iw, int id, float[] vector, VectorSimilarityFunction similarityFunction) throws IOException { Document doc = new Document(); if (vector != null) { diff --git a/lucene/core/src/test/org/apache/lucene/index/TestPendingSoftDeletes.java b/lucene/core/src/test/org/apache/lucene/index/TestPendingSoftDeletes.java index 62420ba501f..bd6e438b49c 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestPendingSoftDeletes.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestPendingSoftDeletes.java @@ -17,7 +17,7 @@ package org.apache.lucene.index; -import static org.apache.lucene.index.VectorValues.SimilarityFunction.NONE; +import static org.apache.lucene.index.VectorSimilarityFunction.NONE; import java.io.IOException; import java.util.Arrays; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java index ed91db2c35d..34eec1cb206 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/KnnGraphTester.java @@ -53,7 +53,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; -import org.apache.lucene.index.VectorValues; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; @@ -73,8 +73,8 @@ public class KnnGraphTester { private static final String KNN_FIELD = "knn"; private static final String ID_FIELD = "id"; - private static final VectorValues.SimilarityFunction SIMILARITY_FUNCTION = - VectorValues.SimilarityFunction.DOT_PRODUCT; + private static final VectorSimilarityFunction SIMILARITY_FUNCTION = + VectorSimilarityFunction.DOT_PRODUCT; private int numDocs; private int dim; @@ -251,7 +251,8 @@ public class KnnGraphTester { private void dumpGraph(Path docsPath) throws IOException { try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) { RandomAccessVectorValues values = vectors.randomAccess(); - HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, 0); + HnswGraphBuilder builder = + new HnswGraphBuilder(vectors, SIMILARITY_FUNCTION, maxConn, beamWidth, 0); // start at node 1 for (int i = 1; i < numDocs; i++) { builder.addGraphNode(values.vectorValue(i)); @@ -580,8 +581,7 @@ public class KnnGraphTester { iwc.setRAMBufferSizeMB(1994d); // iwc.setMaxBufferedDocs(10000); - FieldType fieldType = - VectorField.createFieldType(dim, VectorValues.SimilarityFunction.DOT_PRODUCT); + FieldType fieldType = VectorField.createFieldType(dim, VectorSimilarityFunction.DOT_PRODUCT); if (quiet == false) { iwc.setInfoStream(new PrintStreamInfoStream(System.out)); System.out.println("creating index in " + indexPath); @@ -675,11 +675,6 @@ public class KnnGraphTester { return dim; } - @Override - public VectorValues.SimilarityFunction similarityFunction() { - return SIMILARITY_FUNCTION; - } - @Override public float[] vectorValue(int targetOrd) { int pos = targetOrd * dim; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index c4aa88569f6..ee52a111cdf 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -30,13 +30,11 @@ class MockVectorValues extends VectorValues protected final int dimension; protected final float[][] denseValues; protected final float[][] values; - protected final SimilarityFunction similarityFunction; private final int numVectors; private int pos = -1; - MockVectorValues(SimilarityFunction similarityFunction, float[][] values) { - this.similarityFunction = similarityFunction; + MockVectorValues(float[][] values) { this.dimension = values[0].length; this.values = values; int maxDoc = values.length; @@ -52,7 +50,7 @@ class MockVectorValues extends VectorValues } public MockVectorValues copy() { - return new MockVectorValues(similarityFunction, values); + return new MockVectorValues(values); } @Override @@ -60,11 +58,6 @@ class MockVectorValues extends VectorValues return numVectors; } - @Override - public SimilarityFunction similarityFunction() { - return similarityFunction; - } - @Override public int dimension() { return dimension; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java index 7a1d6cafabe..bb3deaeec31 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnsw.java @@ -41,6 +41,7 @@ import org.apache.lucene.index.KnnGraphValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomAccessVectorValues; import org.apache.lucene.index.RandomAccessVectorValuesProducer; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; import org.apache.lucene.store.Directory; import org.apache.lucene.util.ArrayUtil; @@ -61,7 +62,11 @@ public class TestHnsw extends LuceneTestCase { int maxConn = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 5; long seed = random().nextLong(); - HnswGraphBuilder builder = new HnswGraphBuilder(vectors, maxConn, beamWidth, seed); + VectorSimilarityFunction similarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length - 1) + 1]; + HnswGraphBuilder builder = + new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, seed); HnswGraph hnsw = builder.build(vectors); // Recreate the graph while indexing with the same random seed and write it out @@ -87,7 +92,7 @@ public class TestHnsw extends LuceneTestCase { indexedDoc++; } Document doc = new Document(); - doc.add(new VectorField("field", v2.vectorValue(), v2.similarityFunction)); + doc.add(new VectorField("field", v2.vectorValue())); doc.add(new StoredField("id", v2.docID())); iw.addDocument(doc); nVec++; @@ -97,7 +102,6 @@ public class TestHnsw extends LuceneTestCase { try (IndexReader reader = DirectoryReader.open(dir)) { for (LeafReaderContext ctx : reader.leaves()) { VectorValues values = ctx.reader().getVectorValues("field"); - assertEquals(vectors.similarityFunction, values.similarityFunction()); assertEquals(dim, values.dimension()); assertEquals(nVec, values.size()); assertEquals(indexedDoc, ctx.reader().maxDoc()); @@ -121,11 +125,20 @@ public class TestHnsw extends LuceneTestCase { public void testAknnDiverse() throws IOException { int nDoc = 100; CircularVectorValues vectors = new CircularVectorValues(nDoc); - HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 16, 100, random().nextInt()); + HnswGraphBuilder builder = + new HnswGraphBuilder( + vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt()); HnswGraph hnsw = builder.build(vectors); // run some searches NeighborQueue nn = - HnswGraph.search(new float[] {1, 0}, 10, 5, vectors.randomAccess(), hnsw, random()); + HnswGraph.search( + new float[] {1, 0}, + 10, + 5, + vectors.randomAccess(), + VectorSimilarityFunction.DOT_PRODUCT, + hnsw, + random()); int sum = 0; for (int node : nn.nodes()) { sum += node; @@ -168,20 +181,31 @@ public class TestHnsw extends LuceneTestCase { } public void testHnswGraphBuilderInvalid() { - expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, 0, 0, 0)); + expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0)); expectThrows( IllegalArgumentException.class, - () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 0, 10, 0)); + () -> + new HnswGraphBuilder( + new RandomVectorValues(1, 1, random()), + VectorSimilarityFunction.EUCLIDEAN, + 0, + 10, + 0)); expectThrows( IllegalArgumentException.class, - () -> new HnswGraphBuilder(new RandomVectorValues(1, 1, random()), 10, 0, 0)); + () -> + new HnswGraphBuilder( + new RandomVectorValues(1, 1, random()), + VectorSimilarityFunction.EUCLIDEAN, + 10, + 0, + 0)); } public void testDiversity() throws IOException { // Some carefully checked test cases with simple 2d vectors on the unit circle: MockVectorValues vectors = new MockVectorValues( - VectorValues.SimilarityFunction.DOT_PRODUCT, new float[][] { unitVector2d(0.5), unitVector2d(0.75), @@ -191,7 +215,9 @@ public class TestHnsw extends LuceneTestCase { unitVector2d(0.77), }); // First add nodes until everybody gets a full neighbor list - HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 2, 10, random().nextInt()); + HnswGraphBuilder builder = + new HnswGraphBuilder( + vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt()); // node 0 is added by the builder constructor // builder.addGraphNode(vectors.vectorValue(0)); builder.addGraphNode(vectors.vectorValue(1)); @@ -247,18 +273,22 @@ public class TestHnsw extends LuceneTestCase { int dim = atLeast(10); int topK = 5; RandomVectorValues vectors = new RandomVectorValues(size, dim, random()); - HnswGraphBuilder builder = new HnswGraphBuilder(vectors, 10, 30, random().nextLong()); + VectorSimilarityFunction similarityFunction = + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length - 1) + 1]; + HnswGraphBuilder builder = + new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong()); HnswGraph hnsw = builder.build(vectors); int totalMatches = 0; for (int i = 0; i < 100; i++) { float[] query = randomVector(random(), dim); - NeighborQueue actual = HnswGraph.search(query, topK, 100, vectors, hnsw, random()); - NeighborQueue expected = new NeighborQueue(topK, vectors.similarityFunction.reversed); + NeighborQueue actual = + HnswGraph.search(query, topK, 100, vectors, similarityFunction, hnsw, random()); + NeighborQueue expected = new NeighborQueue(topK, similarityFunction.reversed); for (int j = 0; j < size; j++) { float[] v = vectors.vectorValue(j); if (v != null) { - expected.insertWithOverflow( - j, vectors.similarityFunction.compare(query, vectors.vectorValue(j))); + expected.insertWithOverflow(j, similarityFunction.compare(query, vectors.vectorValue(j))); } } assertEquals(topK, actual.size()); @@ -304,11 +334,6 @@ public class TestHnsw extends LuceneTestCase { return new CircularVectorValues(size); } - @Override - public SimilarityFunction similarityFunction() { - return SimilarityFunction.DOT_PRODUCT; - } - @Override public int dimension() { return 2; @@ -409,13 +434,11 @@ public class TestHnsw extends LuceneTestCase { static class RandomVectorValues extends MockVectorValues { RandomVectorValues(int size, int dimension, Random random) { - super( - SimilarityFunction.values()[random.nextInt(SimilarityFunction.values().length - 1) + 1], - createRandomVectors(size, dimension, random)); + super(createRandomVectors(size, dimension, random)); } RandomVectorValues(RandomVectorValues other) { - super(other.similarityFunction, other.values); + super(other.values); } @Override diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java index a6f68cb3ba1..34916cc9f5d 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.SortedSetDocValues; import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorValues; import org.apache.lucene.search.TopDocs; import org.apache.lucene.util.Bits; @@ -97,7 +98,7 @@ public class TermVectorLeafReader extends LeafReader { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false); fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo}); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java index 4781b644fe4..760f1e64c65 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/codecs/asserting/AssertingVectorFormat.java @@ -93,7 +93,6 @@ public class AssertingVectorFormat extends VectorFormat { assert values.docID() == -1; assert values.size() >= 0; assert values.dimension() > 0; - assert values.similarityFunction() != null; } return values; } diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/BaseFieldInfoFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/index/BaseFieldInfoFormatTestCase.java index b17e671ac11..cff939ac247 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/index/BaseFieldInfoFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/index/BaseFieldInfoFormatTestCase.java @@ -341,8 +341,8 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes if (r.nextBoolean()) { int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS); - VectorValues.SimilarityFunction similarityFunction = - RandomPicks.randomFrom(r, VectorValues.SimilarityFunction.values()); + VectorSimilarityFunction similarityFunction = + RandomPicks.randomFrom(r, VectorSimilarityFunction.values()); type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction); } @@ -412,7 +412,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false); } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java index e1e4a896dc6..2436e90633a 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/index/BaseVectorFormatTestCase.java @@ -49,15 +49,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa @Override protected void addRandomFields(Document doc) { - doc.add(new VectorField("v2", randomVector(30), VectorValues.SimilarityFunction.NONE)); + doc.add(new VectorField("v2", randomVector(30), VectorSimilarityFunction.NONE)); } public void testFieldConstructor() { float[] v = new float[1]; VectorField field = new VectorField("f", v); assertEquals(1, field.fieldType().vectorDimension()); - assertEquals( - VectorValues.SimilarityFunction.EUCLIDEAN, field.fieldType().vectorSimilarityFunction()); + assertEquals(VectorSimilarityFunction.EUCLIDEAN, field.fieldType().vectorSimilarityFunction()); assertSame(v, field.vectorValue()); } @@ -66,7 +65,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa expectThrows(IllegalArgumentException.class, () -> new VectorField("f", null)); expectThrows( IllegalArgumentException.class, - () -> new VectorField("f", new float[1], (VectorValues.SimilarityFunction) null)); + () -> new VectorField("f", new float[1], (VectorSimilarityFunction) null)); expectThrows(IllegalArgumentException.class, () -> new VectorField("f", new float[0])); expectThrows( IllegalArgumentException.class, @@ -91,11 +90,11 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[3], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc2.add(new VectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2)); String errMsg = @@ -107,12 +106,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.commit(); Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[3], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc2.add(new VectorField("f", new float[3], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2)); String errMsg = @@ -127,11 +126,11 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc2.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2)); String errMsg = @@ -143,12 +142,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.commit(); Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc2.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2)); String errMsg = @@ -162,13 +161,13 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) { Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[1], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2)); assertEquals( @@ -183,13 +182,13 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir, newIndexWriterConfig())) { Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc2.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2)); assertEquals( @@ -203,7 +202,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa public void testAddIndexesDirectory0() throws Exception { String fieldName = "field"; Document doc = new Document(); - doc.add(new VectorField(fieldName, new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.DOT_PRODUCT)); try (Directory dir = newDirectory(); Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { @@ -231,8 +230,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { w.addDocument(doc); } - doc.add( - new VectorField(fieldName, new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.DOT_PRODUCT)); try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { w2.addDocument(doc); w2.addIndexes(dir); @@ -252,7 +250,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa String fieldName = "field"; float[] vector = new float[1]; Document doc = new Document(); - doc.add(new VectorField(fieldName, vector, VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField(fieldName, vector, VectorSimilarityFunction.DOT_PRODUCT)); try (Directory dir = newDirectory(); Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { @@ -283,12 +281,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[5], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT)); w2.addDocument(doc); IllegalArgumentException expected = expectThrows( @@ -306,12 +304,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN)); w2.addDocument(doc); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir)); @@ -328,12 +326,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[5], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT)); w2.addDocument(doc); try (DirectoryReader r = DirectoryReader.open(dir)) { IllegalArgumentException expected = @@ -354,12 +352,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN)); w2.addDocument(doc); try (DirectoryReader r = DirectoryReader.open(dir)) { IllegalArgumentException expected = @@ -380,12 +378,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[5], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[5], VectorSimilarityFunction.DOT_PRODUCT)); w2.addDocument(doc); try (DirectoryReader r = DirectoryReader.open(dir)) { IllegalArgumentException expected = @@ -404,12 +402,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Directory dir2 = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w2 = new IndexWriter(dir2, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.EUCLIDEAN)); w2.addDocument(doc); try (DirectoryReader r = DirectoryReader.open(dir)) { IllegalArgumentException expected = @@ -427,8 +425,8 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); IllegalArgumentException expected = expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc)); assertEquals( @@ -448,10 +446,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa new VectorField( "f", new float[VectorValues.MAX_DIMENSIONS + 1], - VectorValues.SimilarityFunction.DOT_PRODUCT))); + VectorSimilarityFunction.DOT_PRODUCT))); Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[1], VectorValues.SimilarityFunction.EUCLIDEAN)); + doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc2); } } @@ -463,13 +461,11 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Exception e = expectThrows( IllegalArgumentException.class, - () -> - doc.add( - new VectorField("f", new float[0], VectorValues.SimilarityFunction.NONE))); + () -> doc.add(new VectorField("f", new float[0], VectorSimilarityFunction.NONE))); assertEquals("cannot index an empty vector", e.getMessage()); Document doc2 = new Document(); - doc2.add(new VectorField("f", new float[1], VectorValues.SimilarityFunction.NONE)); + doc2.add(new VectorField("f", new float[1], VectorSimilarityFunction.NONE)); w.addDocument(doc2); } } @@ -479,14 +475,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } IndexWriterConfig iwc = newIndexWriterConfig(); iwc.setCodec(Codec.forName("SimpleText")); try (IndexWriter w = new IndexWriter(dir, iwc)) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.forceMerge(1); } @@ -500,12 +496,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, iwc)) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); } try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("f", new float[4], VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.forceMerge(1); } @@ -513,8 +509,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa } public void testInvalidVectorFieldUsage() { - VectorField field = - new VectorField("field", new float[2], VectorValues.SimilarityFunction.NONE); + VectorField field = new VectorField("field", new float[2], VectorSimilarityFunction.NONE); expectThrows(IllegalArgumentException.class, () -> field.setIntValue(14)); @@ -528,8 +523,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); doc.add(new StringField("id", "0", Field.Store.NO)); - doc.add( - new VectorField("v", new float[] {2, 3, 5}, VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("v", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.addDocument(new Document()); w.commit(); @@ -554,16 +548,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); doc.add(new StringField("id", "0", Field.Store.NO)); - doc.add( - new VectorField( - "v0", new float[] {2, 3, 5}, VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("v0", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.commit(); doc = new Document(); - doc.add( - new VectorField( - "v1", new float[] {2, 3, 5}, VectorValues.SimilarityFunction.DOT_PRODUCT)); + doc.add(new VectorField("v1", new float[] {2, 3, 5}, VectorSimilarityFunction.DOT_PRODUCT)); w.addDocument(doc); w.forceMerge(1); } @@ -575,13 +565,12 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa int[] fieldDocCounts = new int[numFields]; float[] fieldTotals = new float[numFields]; int[] fieldDims = new int[numFields]; - VectorValues.SimilarityFunction[] fieldSearchStrategies = - new VectorValues.SimilarityFunction[numFields]; + VectorSimilarityFunction[] fieldSearchStrategies = new VectorSimilarityFunction[numFields]; for (int i = 0; i < numFields; i++) { fieldDims[i] = random().nextInt(20) + 1; fieldSearchStrategies[i] = - VectorValues.SimilarityFunction.values()[ - random().nextInt(VectorValues.SimilarityFunction.values().length)]; + VectorSimilarityFunction.values()[ + random().nextInt(VectorSimilarityFunction.values().length)]; } try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) { @@ -628,15 +617,15 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory(); IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) { Document doc1 = new Document(); - doc1.add(new VectorField(fieldName, v, VectorValues.SimilarityFunction.EUCLIDEAN)); + doc1.add(new VectorField(fieldName, v, VectorSimilarityFunction.EUCLIDEAN)); v[0] = 1; Document doc2 = new Document(); - doc2.add(new VectorField(fieldName, v, VectorValues.SimilarityFunction.EUCLIDEAN)); + doc2.add(new VectorField(fieldName, v, VectorSimilarityFunction.EUCLIDEAN)); iw.addDocument(doc1); iw.addDocument(doc2); v[0] = 2; Document doc3 = new Document(); - doc3.add(new VectorField(fieldName, v, VectorValues.SimilarityFunction.EUCLIDEAN)); + doc3.add(new VectorField(fieldName, v, VectorSimilarityFunction.EUCLIDEAN)); iw.addDocument(doc3); iw.forceMerge(1); try (IndexReader reader = iw.getReader()) { @@ -691,16 +680,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); float[] v = new float[] {1}; - doc.add(new VectorField("field1", v, VectorValues.SimilarityFunction.EUCLIDEAN)); - doc.add( - new VectorField("field2", new float[] {1, 2, 3}, VectorValues.SimilarityFunction.NONE)); + doc.add(new VectorField("field1", v, VectorSimilarityFunction.EUCLIDEAN)); + doc.add(new VectorField("field2", new float[] {1, 2, 3}, VectorSimilarityFunction.NONE)); iw.addDocument(doc); v[0] = 2; iw.addDocument(doc); doc = new Document(); doc.add( - new VectorField( - "field3", new float[] {1, 2, 3}, VectorValues.SimilarityFunction.DOT_PRODUCT)); + new VectorField("field3", new float[] {1, 2, 3}, VectorSimilarityFunction.DOT_PRODUCT)); iw.addDocument(doc); iw.forceMerge(1); try (IndexReader reader = iw.getReader()) { @@ -761,9 +748,9 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa if (random().nextBoolean() && values[i] != null) { // sometimes use a shared scratch array System.arraycopy(values[i], 0, scratch, 0, scratch.length); - add(iw, fieldName, i, scratch, VectorValues.SimilarityFunction.NONE); + add(iw, fieldName, i, scratch, VectorSimilarityFunction.NONE); } else { - add(iw, fieldName, i, values[i], VectorValues.SimilarityFunction.NONE); + add(iw, fieldName, i, values[i], VectorSimilarityFunction.NONE); } if (random().nextInt(10) == 2) { // sometimes delete a random document @@ -834,7 +821,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa } id2value[id] = value; id2ord[id] = i; - add(iw, fieldName, id, value, VectorValues.SimilarityFunction.EUCLIDEAN); + add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN); } try (IndexReader reader = iw.getReader()) { for (LeafReaderContext ctx : reader.leaves()) { @@ -871,14 +858,14 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa String field, int id, float[] vector, - VectorValues.SimilarityFunction similarityFunction) + VectorSimilarityFunction similarityFunction) throws IOException { add(iw, field, id, random().nextInt(100), vector, similarityFunction); } private void add(IndexWriter iw, String field, int id, int sortkey, float[] vector) throws IOException { - add(iw, field, id, sortkey, vector, VectorValues.SimilarityFunction.NONE); + add(iw, field, id, sortkey, vector, VectorSimilarityFunction.NONE); } private void add( @@ -887,7 +874,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa int id, int sortkey, float[] vector, - VectorValues.SimilarityFunction similarityFunction) + VectorSimilarityFunction similarityFunction) throws IOException { Document doc = new Document(); if (vector != null) { @@ -913,10 +900,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new VectorField("v1", randomVector(3), VectorValues.SimilarityFunction.NONE)); + doc.add(new VectorField("v1", randomVector(3), VectorSimilarityFunction.NONE)); w.addDocument(doc); - doc.add(new VectorField("v2", randomVector(3), VectorValues.SimilarityFunction.NONE)); + doc.add(new VectorField("v2", randomVector(3), VectorSimilarityFunction.NONE)); w.addDocument(doc); } @@ -937,10 +924,10 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa public void testSimilarityFunctionIdentifiers() { // make sure we don't accidentally mess up similarity function identifiers by re-ordering their // enumerators - assertEquals(0, VectorValues.SimilarityFunction.NONE.ordinal()); - assertEquals(1, VectorValues.SimilarityFunction.EUCLIDEAN.ordinal()); - assertEquals(2, VectorValues.SimilarityFunction.DOT_PRODUCT.ordinal()); - assertEquals(3, VectorValues.SimilarityFunction.values().length); + assertEquals(0, VectorSimilarityFunction.NONE.ordinal()); + assertEquals(1, VectorSimilarityFunction.EUCLIDEAN.ordinal()); + assertEquals(2, VectorSimilarityFunction.DOT_PRODUCT.ordinal()); + assertEquals(3, VectorSimilarityFunction.values().length); } public void testAdvance() throws Exception { @@ -952,7 +939,7 @@ public abstract class BaseVectorFormatTestCase extends BaseIndexFileFormatTestCa Document doc = new Document(); // randomly add a vector field if (random().nextInt(4) == 3) { - doc.add(new VectorField(fieldName, new float[4], VectorValues.SimilarityFunction.NONE)); + doc.add(new VectorField(fieldName, new float[4], VectorSimilarityFunction.NONE)); } w.addDocument(doc); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/index/RandomPostingsTester.java b/lucene/test-framework/src/java/org/apache/lucene/index/RandomPostingsTester.java index d6fcadd2620..45a3434a0e5 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/index/RandomPostingsTester.java +++ b/lucene/test-framework/src/java/org/apache/lucene/index/RandomPostingsTester.java @@ -140,7 +140,7 @@ public class RandomPostingsTester { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false); fieldUpto++; @@ -711,7 +711,7 @@ public class RandomPostingsTester { 0, 0, 0, - VectorValues.SimilarityFunction.NONE, + VectorSimilarityFunction.NONE, false); }