add tests for int4BitDotProduct

This commit is contained in:
ChrisHegarty 2024-12-18 15:43:11 +00:00
parent 62cd45ba54
commit db10587ba1
2 changed files with 139 additions and 2 deletions

View File

@ -21,8 +21,8 @@ import org.junit.BeforeClass;
public abstract class BaseVectorizationTestCase extends LuceneTestCase { public abstract class BaseVectorizationTestCase extends LuceneTestCase {
protected static final VectorizationProvider LUCENE_PROVIDER = new DefaultVectorizationProvider(); protected static final VectorizationProvider LUCENE_PROVIDER = defaultProvider();
protected static final VectorizationProvider PANAMA_PROVIDER = VectorizationProvider.lookup(true); protected static final VectorizationProvider PANAMA_PROVIDER = maybePanamaProvider();
@BeforeClass @BeforeClass
public static void beforeClass() throws Exception { public static void beforeClass() throws Exception {
@ -30,4 +30,12 @@ public abstract class BaseVectorizationTestCase extends LuceneTestCase {
"Test only works when JDK's vector incubator module is enabled.", "Test only works when JDK's vector incubator module is enabled.",
PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass()); PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass());
} }
/**
 * Creates a fresh {@link DefaultVectorizationProvider}.
 *
 * <p>Exposed as a public static factory so that tests outside this package can obtain the
 * provider without calling the constructor themselves.
 *
 * @return a newly constructed default provider
 */
public static VectorizationProvider defaultProvider() {
  final VectorizationProvider provider = new DefaultVectorizationProvider();
  return provider;
}
/**
 * Returns whatever {@code VectorizationProvider.lookup(true)} resolves to.
 *
 * <p>NOTE(review): judging by the assumption in {@code beforeClass}, the result can be of the same
 * class as the default provider when the JDK vector incubator module is not enabled - callers
 * should not rely on getting a Panama-backed implementation.
 *
 * @return the provider selected by the lookup
 */
public static VectorizationProvider maybePanamaProvider() {
  final VectorizationProvider provider = VectorizationProvider.lookup(true);
  return provider;
}
} }

View File

@ -16,8 +16,13 @@
*/ */
package org.apache.lucene.util; package org.apache.lucene.util;
import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween;
import java.util.Arrays;
import java.util.Random; import java.util.Random;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
@ -384,4 +389,128 @@ public class TestVectorUtil extends LuceneTestCase {
} }
return length; return length;
} }
/**
 * Checks that {@code VectorUtil.int4BitDotProduct} rejects a query array that is one byte shorter
 * than four times the document array (the passing tests in this class use exactly {@code size * 4}).
 */
public void testInt4BitDotProductInvariants() {
  final int rounds = atLeast(10);
  int round = 0;
  while (round < rounds) {
    final int docLength = randomIntBetween(random(), 1, 10);
    final byte[] doc = new byte[docLength];
    // Deliberately one byte short of 4 * docLength.
    final byte[] query = new byte[4 * docLength - 1];
    expectThrows(
        IllegalArgumentException.class, () -> VectorUtil.int4BitDotProduct(query, doc));
    round++;
  }
}
// Provider obtained from BaseVectorizationTestCase.defaultProvider(); used by the tests below to
// exercise the provider's VectorUtilSupport directly rather than going through VectorUtil.
static final VectorizationProvider defaultedProvider =
BaseVectorizationTestCase.defaultProvider();
// Provider obtained from BaseVectorizationTestCase.maybePanamaProvider().
// NOTE(review): as the name suggests, this is presumably either the default or a Panama-backed
// implementation depending on the runtime - confirm against VectorizationProvider.lookup.
static final VectorizationProvider defOrPanamaProvider =
BaseVectorizationTestCase.maybePanamaProvider();
/**
 * Runs the fixed-input dot-product checks against three entry points, in order: {@code
 * VectorUtil}, the support object of {@code defaultedProvider}, and the support object of {@code
 * defOrPanamaProvider}.
 */
public void testBasicInt4BitDotProduct() {
  final Int4BitDotProduct viaVectorUtil = VectorUtil::int4BitDotProduct;
  testBasicInt4BitDotProductImpl(viaVectorUtil);

  final Int4BitDotProduct viaDefaultSupport =
      defaultedProvider.getVectorUtilSupport()::int4BitDotProduct;
  testBasicInt4BitDotProductImpl(viaDefaultSupport);

  final Int4BitDotProduct viaLookupSupport =
      defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct;
  testBasicInt4BitDotProductImpl(viaLookupSupport);
}
interface Int4BitDotProduct {
long apply(byte[] q, byte[] d);
}
/**
 * Verifies an implementation against hand-computed results for small fixed inputs.
 *
 * <p>After two literal cases, the method checks {@code d = {1, ..., n}} for {@code n = 3..9} with
 * {@code q} being four back-to-back copies of {@code d}. With P the number of set bits in {@code
 * d}, each of the four segments of {@code q} contributes P shifted left by its index, so the
 * expected value is {@code P + 2P + 4P + 8P}.
 *
 * <p>The scalar reference is checked with {@code assertEquals} rather than the {@code assert}
 * keyword, so the check still runs when JVM assertions are disabled.
 *
 * @param Int4BitDotProductFunc the implementation under test
 */
void testBasicInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) {
  assertEquals(15L, Int4BitDotProductFunc.apply(new byte[] {1, 1, 1, 1}, new byte[] {1}));
  assertEquals(
      30L, Int4BitDotProductFunc.apply(new byte[] {1, 2, 1, 2, 1, 2, 1, 2}, new byte[] {1, 2}));

  // expected[k] belongs to d = {1, ..., k + 3}; the set-bit counts P are 4, 5, 7, 9, 12, 13, 15.
  final long[] expected = {
    60L, // 4 + 8 + 16 + 32
    75L, // 5 + 10 + 20 + 40
    105L, // 7 + 14 + 28 + 56
    135L, // 9 + 18 + 36 + 72
    180L, // 12 + 24 + 48 + 96
    195L, // 13 + 26 + 52 + 104
    225L // 15 + 30 + 60 + 120
  };
  for (int k = 0; k < expected.length; k++) {
    final int dims = k + 3;
    final byte[] d = new byte[dims];
    for (int j = 0; j < dims; j++) {
      d[j] = (byte) (j + 1);
    }
    // q is d repeated four times, one copy per bit plane.
    final byte[] q = new byte[4 * dims];
    for (int plane = 0; plane < 4; plane++) {
      System.arraycopy(d, 0, q, plane * dims, dims);
    }
    // Sanity-check the reference implementation first, then the implementation under test.
    assertEquals("scalar reference, dims=" + dims, expected[k], scalarInt4BitDotProduct(q, d));
    assertEquals("dims=" + dims, expected[k], Int4BitDotProductFunc.apply(q, d));
  }
}
/**
 * Runs the randomized dot-product comparison against three entry points, in order: {@code
 * VectorUtil}, the support object of {@code defaultedProvider}, and the support object of {@code
 * defOrPanamaProvider}.
 */
public void testInt4BitDotProduct() {
  final Int4BitDotProduct viaVectorUtil = VectorUtil::int4BitDotProduct;
  testInt4BitDotProductImpl(viaVectorUtil);

  final Int4BitDotProduct viaDefaultSupport =
      defaultedProvider.getVectorUtilSupport()::int4BitDotProduct;
  testInt4BitDotProductImpl(viaDefaultSupport);

  final Int4BitDotProduct viaLookupSupport =
      defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct;
  testInt4BitDotProductImpl(viaLookupSupport);
}
/**
 * Compares an implementation with {@code scalarInt4BitDotProduct} on random data and on arrays
 * filled entirely with {@code Byte.MAX_VALUE} and then {@code Byte.MIN_VALUE}.
 *
 * <p>Each round draws a document length below 5000 (zero is possible) and a query four times as
 * long.
 *
 * @param Int4BitDotProductFunc the implementation under test
 */
void testInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) {
  final int rounds = atLeast(50);
  for (int round = 0; round < rounds; round++) {
    final int docLength = random().nextInt(5000);
    final byte[] d = new byte[docLength];
    final byte[] q = new byte[4 * docLength];

    // Random contents: document first, then query, to keep the random stream order stable.
    random().nextBytes(d);
    random().nextBytes(q);
    assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));

    // Extremes: every byte at the maximum, then every byte at the minimum.
    for (byte fill : new byte[] {Byte.MAX_VALUE, Byte.MIN_VALUE}) {
      Arrays.fill(d, fill);
      Arrays.fill(q, fill);
      assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
    }
  }
}
/**
 * Straightforward reference implementation used to cross-check the implementations under test.
 *
 * <p>Treats {@code q} as four consecutive segments of {@code d.length} bytes. For segment {@code
 * i}, it counts the bits set in the AND of that segment with {@code d} and adds the count shifted
 * left by {@code i}.
 *
 * @param q array read at offsets {@code 0 .. 4 * d.length - 1}
 * @param d array whose length defines the segment size
 * @return the weighted sum of the four segment bit counts
 */
static int scalarInt4BitDotProduct(byte[] q, byte[] d) {
  final int stride = d.length;
  int total = 0;
  int plane = 0;
  while (plane < 4) {
    final int bits = popcount(q, plane * stride, d, stride);
    total += bits << plane;
    plane++;
  }
  return total;
}
/**
 * Counts the bits set in the bytewise AND of {@code a[aOffset .. aOffset + length)} and {@code
 * b[0 .. length)}.
 *
 * @param a first array, read starting at {@code aOffset}
 * @param aOffset index in {@code a} of the first byte to use
 * @param b second array, read starting at index 0
 * @param length number of byte pairs to combine
 * @return total number of one-bits across all {@code a[aOffset + j] & b[j]}
 * @throws ArrayIndexOutOfBoundsException if either array is too short for the requested range
 */
public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
  int res = 0;
  for (int j = 0; j < length; j++) {
    // Mask to the low 8 bits so sign extension of negative bytes does not add spurious ones.
    res += Integer.bitCount((a[aOffset + j] & b[j]) & 0xFF);
  }
  return res;
}
} }