From 62cd45ba541c9b9b0569d877ca67f994dee5406b Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Wed, 18 Dec 2024 15:32:21 +0000 Subject: [PATCH] test default and panama impls return the same result --- .../vectorization/TestVectorUtilSupport.java | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java index 7064955cb5f..6443de752bf 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorUtilSupport.java @@ -20,6 +20,7 @@ import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; import java.util.Arrays; import java.util.function.ToDoubleFunction; import java.util.function.ToIntFunction; +import java.util.function.ToLongFunction; import java.util.stream.IntStream; public class TestVectorUtilSupport extends BaseVectorizationTestCase { @@ -133,6 +134,27 @@ public class TestVectorUtilSupport extends BaseVectorizationTestCase { PANAMA_PROVIDER.getVectorUtilSupport().int4DotProduct(a, false, pack(b), true)); } + public void testInt4BitDotProduct() { + var binaryQuantized = new byte[size]; + var int4Quantized = new byte[size * 4]; + random().nextBytes(binaryQuantized); + random().nextBytes(int4Quantized); + assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized)); + } + + public void testInt4BitDotProductBoundaries() { + var binaryQuantized = new byte[size]; + var int4Quantized = new byte[size * 4]; + + Arrays.fill(binaryQuantized, Byte.MAX_VALUE); + Arrays.fill(int4Quantized, Byte.MAX_VALUE); + assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized)); + + Arrays.fill(binaryQuantized, Byte.MIN_VALUE); + Arrays.fill(int4Quantized, Byte.MIN_VALUE); + assertLongReturningProviders(p -> p.int4BitDotProduct(int4Quantized, binaryQuantized)); + } + static byte[] pack(byte[] unpacked) { int len = (unpacked.length + 1) / 2; var packed = new byte[len]; @@ -154,4 +176,10 @@ public class TestVectorUtilSupport extends BaseVectorizationTestCase { func.applyAsInt(LUCENE_PROVIDER.getVectorUtilSupport()), func.applyAsInt(PANAMA_PROVIDER.getVectorUtilSupport())); } + + private void assertLongReturningProviders(ToLongFunction func) { + assertEquals( + func.applyAsLong(LUCENE_PROVIDER.getVectorUtilSupport()), + func.applyAsLong(PANAMA_PROVIDER.getVectorUtilSupport())); + } }