diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java
index 97c05435b96..341e28c36f5 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java
@@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene94;
import java.io.IOException;
import java.util.Collections;
+import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.DocValuesFormat;
@@ -111,6 +112,8 @@ import org.apache.lucene.store.IndexOutput;
*
0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* 1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* 2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
+ * 3: MAXIMUM_INNER_PRODUCT similarity. ({@link
+ * VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT})
*
*
*
@@ -284,10 +287,38 @@ public final class Lucene94FieldInfosFormat extends FieldInfosFormat {
}
private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
- if (b < 0 || b >= VectorSimilarityFunction.values().length) {
- throw new CorruptIndexException("invalid distance function: " + b, input);
+ // Decode the stored byte via distOrdToFunc, which rejects values outside this format's own list.
+ try {
+ return distOrdToFunc(b);
+ } catch (IllegalArgumentException e) {
+ // Report an undecodable byte as index corruption, keeping the original exception as the cause.
+ throw new CorruptIndexException("invalid distance function: " + b, input, e);
}
- return VectorSimilarityFunction.values()[b];
+ }
+
+  // List of vector similarity functions. This list is defined here, in order
+  // to avoid an undesirable dependency on the declaration and order of values
+  // in VectorSimilarityFunction. The list values and order have been chosen to
+  // match that of VectorSimilarityFunction in, at least, Lucene 9.10. The byte
+  // written for a field is its position in this list, so existing entries must
+  // not be reordered or removed.
+  static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS =
+      List.of(
+          VectorSimilarityFunction.EUCLIDEAN,
+          VectorSimilarityFunction.DOT_PRODUCT,
+          VectorSimilarityFunction.COSINE,
+          VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
+
+  /**
+   * Translates a stored ordinal into the similarity function it denotes.
+   *
+   * @param i ordinal, used as an index into {@code SIMILARITY_FUNCTIONS}
+   * @return the entry of {@code SIMILARITY_FUNCTIONS} at index {@code i}
+   * @throws IllegalArgumentException if {@code i} is negative or not smaller than the list size
+   */
+  static VectorSimilarityFunction distOrdToFunc(byte i) {
+    final boolean inRange = i >= 0 && i < SIMILARITY_FUNCTIONS.size();
+    if (inRange) {
+      return SIMILARITY_FUNCTIONS.get(i);
+    }
+    throw new IllegalArgumentException("invalid distance function: " + i);
+  }
+
+  /**
+   * Translates a similarity function into the ordinal stored for it.
+   *
+   * @param func the function to look up in {@code SIMILARITY_FUNCTIONS}
+   * @return the index of the first entry equal to {@code func}, narrowed to a byte
+   * @throws IllegalArgumentException if no entry of the list equals {@code func}
+   */
+  static byte distFuncToOrd(VectorSimilarityFunction func) {
+    final int count = SIMILARITY_FUNCTIONS.size();
+    int ord = 0;
+    // Scan forward until an entry matches or the list is exhausted.
+    while (ord < count && !SIMILARITY_FUNCTIONS.get(ord).equals(func)) {
+      ord++;
+    }
+    if (ord == count) {
+      throw new IllegalArgumentException("invalid distance function: " + func);
+    }
+    return (byte) ord;
}
static {
@@ -378,7 +409,7 @@ public final class Lucene94FieldInfosFormat extends FieldInfosFormat {
}
output.writeVInt(fi.getVectorDimension());
output.writeByte((byte) fi.getVectorEncoding().ordinal());
- output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal());
+ output.writeByte(distFuncToOrd(fi.getVectorSimilarityFunction()));
}
CodecUtil.writeFooter(output);
}
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
index 9ebac62ce9b..efb51c963e0 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java
@@ -22,6 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.FlatVectorsReader;
@@ -171,15 +172,24 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
}
}
+  // List of vector similarity functions. This list is defined here, in order
+  // to avoid an undesirable dependency on the declaration and order of values
+  // in VectorSimilarityFunction. The list values and order must be identical
+  // to that of
+  // org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat#SIMILARITY_FUNCTIONS.
+  public static final List<VectorSimilarityFunction> SIMILARITY_FUNCTIONS =
+      List.of(
+          VectorSimilarityFunction.EUCLIDEAN,
+          VectorSimilarityFunction.DOT_PRODUCT,
+          VectorSimilarityFunction.COSINE,
+          VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
+
+ /**
+ * Reads one int from {@code input} and returns the entry of {@code SIMILARITY_FUNCTIONS} at
+ * that index.
+ *
+ * @param input source the similarity function id is read from
+ * @return the similarity function stored at the index that was read
+ * @throws CorruptIndexException if the id is negative or not smaller than the list size
+ * @throws IOException if reading from {@code input} fails
+ */
public static VectorSimilarityFunction readSimilarityFunction(DataInput input)
throws IOException {
- int similarityFunctionId = input.readInt();
- if (similarityFunctionId < 0
- || similarityFunctionId >= VectorSimilarityFunction.values().length) {
- throw new CorruptIndexException(
- "Invalid similarity function id: " + similarityFunctionId, input);
+ int i = input.readInt();
+ if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) {
+ // An id outside the list cannot be decoded; keep reporting it as index corruption (with the
+ // input attached), as this method did before, instead of as an unchecked argument error.
+ throw new CorruptIndexException("invalid distance function: " + i, input);
}
- return VectorSimilarityFunction.values()[similarityFunctionId];
+ return SIMILARITY_FUNCTIONS.get(i);
}
public static VectorEncoding readVectorEncoding(DataInput input) throws IOException {
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
index 174c65db9ac..a236dd7c65b 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java
@@ -18,6 +18,7 @@
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS;
import java.io.IOException;
import java.util.ArrayList;
@@ -33,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.IndexOutput;
@@ -436,7 +438,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
- meta.writeInt(field.getVectorSimilarityFunction().ordinal());
+ meta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction()));
meta.writeVLong(vectorIndexOffset);
meta.writeVLong(vectorIndexLength);
meta.writeVInt(field.getVectorDimension());
@@ -500,6 +502,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
IOUtils.close(meta, vectorIndex, flatVectorWriter);
}
+  /**
+   * Translates a similarity function into the int written to the metadata for it.
+   *
+   * @param func the function to look up in {@code SIMILARITY_FUNCTIONS}
+   * @return the index of the first entry equal to {@code func}
+   * @throws IllegalArgumentException if no entry of the list equals {@code func}
+   */
+  static int distFuncToOrd(VectorSimilarityFunction func) {
+    for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) {
+      if (SIMILARITY_FUNCTIONS.get(i).equals(func)) {
+        // The ordinal is written as a full int, so return the index without narrowing it.
+        return i;
+      }
+    }
+    throw new IllegalArgumentException("invalid distance function: " + func);
+  }
+
private static class FieldWriter extends KnnFieldVectorsWriter {
private static final long SHALLOW_SIZE =
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene94/TestLucene94FieldInfosFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene94/TestLucene94FieldInfosFormat.java
new file mode 100644
index 00000000000..c69eeadf5e6
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene94/TestLucene94FieldInfosFormat.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.codecs.lucene94;
+
+import java.util.Arrays;
+import org.apache.lucene.codecs.Codec;
+import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.tests.index.BaseFieldInfoFormatTestCase;
+import org.apache.lucene.tests.util.TestUtil;
+
+/** Tests for {@code Lucene94FieldInfosFormat}, run against the default codec. */
+public class TestLucene94FieldInfosFormat extends BaseFieldInfoFormatTestCase {
+  @Override
+  protected Codec getCodec() {
+    return TestUtil.getDefaultCodec();
+  }
+
+  // Ensures that all expected vector similarity functions are translatable
+  // in the format.
+  public void testVectorSimilarityFuncs() {
+    // This does not necessarily have to be all similarity functions, but
+    // differences should be considered carefully.
+    var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
+
+    // assertEquals takes (expected, actual); keeping that order makes a failure
+    // message name the format's list as the value that was wrong.
+    assertEquals(expectedValues, Lucene94FieldInfosFormat.SIMILARITY_FUNCTIONS);
+  }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
index be0b01f3e0b..382389bc8f3 100644
--- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
+++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java
@@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
@@ -186,4 +187,13 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
new Lucene99HnswScalarQuantizedVectorsFormat(
20, 100, 1, null, new SameThreadExecutorService()));
}
+
+  // Ensures that all expected vector similarity functions are translatable
+  // in the format.
+  public void testVectorSimilarityFuncs() {
+    // This does not necessarily have to be all similarity functions, but
+    // differences should be considered carefully.
+    var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList();
+    // assertEquals takes (expected, actual); keeping that order makes a failure
+    // message name the reader's list as the value that was wrong.
+    assertEquals(expectedValues, Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS);
+  }
}