diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 9e58b4fea57..782675ce4d7 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -148,6 +148,9 @@ New Features search results can be provided. The first custom collector provides `ToParentBlockJoin[Float|Byte]KnnVectorQuery` joining child vector documents with their parent documents. (Ben Trent) +* GITHUB#12479: Add new Maximum Inner Product vector similarity function for non-normalized dot-product + vector search. (Jack Mazanec, Ben Trent) + Improvements --------------------- * GITHUB#12374: Add CachingLeafSlicesSupplier to compute the LeafSlices for concurrent segment search (Sorabh Hamirwasia) diff --git a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java index 8a515cb79fc..ae0633e8c0f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java +++ b/lucene/core/src/java/org/apache/lucene/index/VectorSimilarityFunction.java @@ -19,6 +19,7 @@ package org.apache.lucene.index; import static org.apache.lucene.util.VectorUtil.cosine; import static org.apache.lucene.util.VectorUtil.dotProduct; import static org.apache.lucene.util.VectorUtil.dotProductScore; +import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore; import static org.apache.lucene.util.VectorUtil.squareDistance; /** @@ -76,6 +77,23 @@ public enum VectorSimilarityFunction { public float compare(byte[] v1, byte[] v2) { return (1 + cosine(v1, v2)) / 2; } + }, + + /** + * Maximum inner product. This is like {@link VectorSimilarityFunction#DOT_PRODUCT}, but does not + * require normalization of the inputs. Should be used when the embedding vectors store useful + * information within the vector magnitude + */ + MAXIMUM_INNER_PRODUCT { + @Override + public float compare(float[] v1, float[] v2) { + return scaleMaxInnerProductScore(dotProduct(v1, v2)); + } + + @Override + public float compare(byte[] v1, byte[] v2) { + return scaleMaxInnerProductScore(dotProduct(v1, v2)); + } }; /** diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index b8819082ba9..1af99245806 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -164,6 +164,17 @@ public final class VectorUtil { return 0.5f + dotProduct(a, b) / denom; } + /** + * @param vectorDotProductSimilarity the raw similarity between two vectors + * @return A scaled score preventing negative scores for maximum-inner-product + */ + public static float scaleMaxInnerProductScore(float vectorDotProductSimilarity) { + if (vectorDotProductSimilarity < 0) { + return 1 / (1 + -1 * vectorDotProductSimilarity); + } + return vectorDotProductSimilarity + 1; + } + /** * Checks if a float vector only has finite components. * diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index d54db77c739..e2f47865051 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -346,6 +346,29 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { } } + public void testScoreMIP() throws IOException { + try (Directory indexStore = + getIndexStore( + "field", + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, + new float[] {0, 1}, + new float[] {1, 2}, + new float[] {0, 0}); + IndexReader reader = DirectoryReader.open(indexStore)) { + IndexSearcher searcher = newSearcher(reader); + AbstractKnnVectorQuery kvq = getKnnVectorQuery("field", new float[] {0, -1}, 10); + assertMatches(searcher, kvq, 3); + ScoreDoc[] scoreDocs = searcher.search(kvq, 3).scoreDocs; + assertIdMatches(reader, "id2", scoreDocs[0]); + assertIdMatches(reader, "id0", scoreDocs[1]); + assertIdMatches(reader, "id1", scoreDocs[2]); + + assertEquals(1.0, scoreDocs[0].score, 1e-7); + assertEquals(1 / 2f, scoreDocs[1].score, 1e-7); + assertEquals(1 / 3f, scoreDocs[2].score, 1e-7); + } + } + public void testExplain() throws IOException { try (Directory d = newDirectory()) { try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) { @@ -739,11 +762,21 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { /** Creates a new directory and adds documents with the given vectors as kNN vector fields */ Directory getIndexStore(String field, float[]... contents) throws IOException { + return getIndexStore(field, VectorSimilarityFunction.EUCLIDEAN, contents); + } + + /** + * Creates a new directory and adds documents with the given vectors with similarity as kNN vector + * fields + */ + Directory getIndexStore( + String field, VectorSimilarityFunction vectorSimilarityFunction, float[]... contents) + throws IOException { Directory indexStore = newDirectory(); RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore); for (int i = 0; i < contents.length; ++i) { Document doc = new Document(); - doc.add(getKnnVectorField(field, contents[i])); + doc.add(getKnnVectorField(field, contents[i], vectorSimilarityFunction)); doc.add(new StringField("id", "id" + i, Field.Store.YES)); writer.addDocument(doc); } diff --git a/lucene/luke/src/java/org/apache/lucene/luke/app/desktop/components/DocumentsPanelProvider.java b/lucene/luke/src/java/org/apache/lucene/luke/app/desktop/components/DocumentsPanelProvider.java index 613cca415eb..56774640848 100644 --- a/lucene/luke/src/java/org/apache/lucene/luke/app/desktop/components/DocumentsPanelProvider.java +++ b/lucene/luke/src/java/org/apache/lucene/luke/app/desktop/components/DocumentsPanelProvider.java @@ -1246,6 +1246,9 @@ public final class DocumentsPanelProvider implements DocumentsTabOperator { case EUCLIDEAN: sb.append("euc"); break; + case MAXIMUM_INNER_PRODUCT: + sb.append("mip"); + break; default: sb.append("???"); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 7f5dce9aa5e..4167e7a8a38 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -1278,7 +1278,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe assertEquals(0, VectorSimilarityFunction.EUCLIDEAN.ordinal()); assertEquals(1, VectorSimilarityFunction.DOT_PRODUCT.ordinal()); assertEquals(2, VectorSimilarityFunction.COSINE.ordinal()); - assertEquals(3, VectorSimilarityFunction.values().length); + assertEquals(3, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT.ordinal()); + assertEquals(4, VectorSimilarityFunction.values().length); } public void testVectorEncodingOrdinals() {