diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java index 97c05435b96..341e28c36f5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene94/Lucene94FieldInfosFormat.java @@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene94; import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.DocValuesFormat; @@ -111,6 +112,8 @@ import org.apache.lucene.store.IndexOutput; *
  • 0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN}) *
  • 1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT}) *
  • 2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE}) + *
  • 3: MAXIMUM_INNER_PRODUCT similarity. ({@link + * VectorSimilarityFunction#MAXIMUM_INNER_PRODUCT}) * * * @@ -284,10 +287,38 @@ public final class Lucene94FieldInfosFormat extends FieldInfosFormat { } private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException { - if (b < 0 || b >= VectorSimilarityFunction.values().length) { - throw new CorruptIndexException("invalid distance function: " + b, input); + try { + return distOrdToFunc(b); + } catch (IllegalArgumentException e) { + throw new CorruptIndexException("invalid distance function: " + b, input, e); } - return VectorSimilarityFunction.values()[b]; + } + + // List of vector similarity functions. This list is defined here, in order to avoid an + // undesirable dependency on the declaration and order of values in VectorSimilarityFunction. + // The list values and order have been chosen to match that of VectorSimilarityFunction in, at + // least, Lucene 9.10. Values must only be appended; list positions are serialized. + static final List SIMILARITY_FUNCTIONS = + List.of( + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.COSINE, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + + static VectorSimilarityFunction distOrdToFunc(byte i) { + if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { + throw new IllegalArgumentException("invalid distance function: " + i); + } + return SIMILARITY_FUNCTIONS.get(i); + } + + static byte distFuncToOrd(VectorSimilarityFunction func) { + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { + return (byte) i; + } + } + throw new IllegalArgumentException("invalid distance function: " + func); } static { @@ -378,7 +409,7 @@ public final class Lucene94FieldInfosFormat extends FieldInfosFormat { } output.writeVInt(fi.getVectorDimension()); output.writeByte((byte) fi.getVectorEncoding().ordinal()); - output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal()); + 
output.writeByte(distFuncToOrd(fi.getVectorSimilarityFunction())); } CodecUtil.writeFooter(output); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 9ebac62ce9b..efb51c963e0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -22,6 +22,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.FlatVectorsReader; @@ -171,15 +172,24 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader } } + // List of vector similarity functions. This list is defined here, in order + // to avoid an undesirable dependency on the declaration and order of values + // in VectorSimilarityFunction. The list values and order must be identical + // to that of {@link o.a.l.c.l.Lucene94FieldInfosFormat#SIMILARITY_FUNCTIONS}. 
+ public static final List SIMILARITY_FUNCTIONS = + List.of( + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.COSINE, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT); + public static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { - int similarityFunctionId = input.readInt(); - if (similarityFunctionId < 0 - || similarityFunctionId >= VectorSimilarityFunction.values().length) { - throw new CorruptIndexException( - "Invalid similarity function id: " + similarityFunctionId, input); + int i = input.readInt(); + if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { + throw new IllegalArgumentException("invalid distance function: " + i); } - return VectorSimilarityFunction.values()[similarityFunctionId]; + return SIMILARITY_FUNCTIONS.get(i); } public static VectorEncoding readVectorEncoding(DataInput input) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index 174c65db9ac..a236dd7c65b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene99; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; import java.io.IOException; import java.util.ArrayList; @@ -33,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import 
org.apache.lucene.search.TaskExecutor; import org.apache.lucene.store.IndexOutput; @@ -436,7 +438,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { throws IOException { meta.writeInt(field.number); meta.writeInt(field.getVectorEncoding().ordinal()); - meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); meta.writeVLong(vectorIndexOffset); meta.writeVLong(vectorIndexLength); meta.writeVInt(field.getVectorDimension()); @@ -500,6 +502,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { IOUtils.close(meta, vectorIndex, flatVectorWriter); } + static int distFuncToOrd(VectorSimilarityFunction func) { + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { + return (byte) i; + } + } + throw new IllegalArgumentException("invalid distance function: " + func); + } + private static class FieldWriter extends KnnFieldVectorsWriter { private static final long SHALLOW_SIZE = diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene94/TestLucene94FieldInfosFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene94/TestLucene94FieldInfosFormat.java new file mode 100644 index 00000000000..c69eeadf5e6 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene94/TestLucene94FieldInfosFormat.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene94; + +import java.util.Arrays; +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.index.BaseFieldInfoFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; + +public class TestLucene94FieldInfosFormat extends BaseFieldInfoFormatTestCase { + @Override + protected Codec getCodec() { + return TestUtil.getDefaultCodec(); + } + + // Ensures that all expected vector similarity functions are translatable + // in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. 
+ var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + + assertEquals(Lucene94FieldInfosFormat.SIMILARITY_FUNCTIONS, expectedValues); + } +} diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index be0b01f3e0b..382389bc8f3 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; @@ -186,4 +187,13 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat new Lucene99HnswScalarQuantizedVectorsFormat( 20, 100, 1, null, new SameThreadExecutorService())); } + + // Ensures that all expected vector similarity functions are translatable + // in the format. + public void testVectorSimilarityFuncs() { + // This does not necessarily have to be all similarity functions, but + // differences should be considered carefully. + var expectedValues = Arrays.stream(VectorSimilarityFunction.values()).toList(); + assertEquals(Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS, expectedValues); + } }