diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 440d4a65aa2..9fe5ddd0e2b 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -17,6 +17,7 @@ package org.apache.lucene.internal.vectorization; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import java.util.Arrays; import java.util.function.ToDoubleFunction; import java.util.function.ToIntFunction; import java.util.stream.IntStream; @@ -62,6 +63,35 @@ public class TestVectorUtilSupport extends BaseVectorizationTestCase { assertFloatReturningProviders(p -> p.cosine(a, b)); } + public void testBinaryVectorsBoundaries() { + var a = new byte[size]; + var b = new byte[size]; + + Arrays.fill(a, Byte.MIN_VALUE); + Arrays.fill(b, Byte.MIN_VALUE); + assertIntReturningProviders(p -> p.dotProduct(a, b)); + assertIntReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + + Arrays.fill(a, Byte.MAX_VALUE); + Arrays.fill(b, Byte.MAX_VALUE); + assertIntReturningProviders(p -> p.dotProduct(a, b)); + assertIntReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + + Arrays.fill(a, Byte.MIN_VALUE); + Arrays.fill(b, Byte.MAX_VALUE); + assertIntReturningProviders(p -> p.dotProduct(a, b)); + assertIntReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + + Arrays.fill(a, Byte.MAX_VALUE); + Arrays.fill(b, Byte.MIN_VALUE); + assertIntReturningProviders(p -> p.dotProduct(a, b)); + assertIntReturningProviders(p -> p.squareDistance(a, b)); + assertFloatReturningProviders(p -> p.cosine(a, b)); + } + private void assertFloatReturningProviders(ToDoubleFunction func) { assertEquals( func.applyAsDouble(LUCENE_PROVIDER.getVectorUtilSupport()),