mirror of https://github.com/apache/lucene.git
add tests for int4BitDotProduct
This commit is contained in:
parent
62cd45ba54
commit
db10587ba1
|
@ -21,8 +21,8 @@ import org.junit.BeforeClass;
|
|||
|
||||
public abstract class BaseVectorizationTestCase extends LuceneTestCase {
|
||||
|
||||
protected static final VectorizationProvider LUCENE_PROVIDER = new DefaultVectorizationProvider();
|
||||
protected static final VectorizationProvider PANAMA_PROVIDER = VectorizationProvider.lookup(true);
|
||||
protected static final VectorizationProvider LUCENE_PROVIDER = defaultProvider();
|
||||
protected static final VectorizationProvider PANAMA_PROVIDER = maybePanamaProvider();
|
||||
|
||||
@BeforeClass
|
||||
public static void beforeClass() throws Exception {
|
||||
|
@ -30,4 +30,12 @@ public abstract class BaseVectorizationTestCase extends LuceneTestCase {
|
|||
"Test only works when JDK's vector incubator module is enabled.",
|
||||
PANAMA_PROVIDER.getClass() != LUCENE_PROVIDER.getClass());
|
||||
}
|
||||
|
||||
public static VectorizationProvider defaultProvider() {
|
||||
return new DefaultVectorizationProvider();
|
||||
}
|
||||
|
||||
public static VectorizationProvider maybePanamaProvider() {
|
||||
return VectorizationProvider.lookup(true);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,8 +16,13 @@
|
|||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import static com.carrotsearch.randomizedtesting.generators.RandomNumbers.randomIntBetween;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.internal.vectorization.BaseVectorizationTestCase;
|
||||
import org.apache.lucene.internal.vectorization.VectorizationProvider;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
||||
|
@ -384,4 +389,128 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||
}
|
||||
return length;
|
||||
}
|
||||
|
||||
public void testInt4BitDotProductInvariants() {
|
||||
int iterations = atLeast(10);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
int size = randomIntBetween(random(), 1, 10);
|
||||
var d = new byte[size];
|
||||
var q = new byte[size * 4 - 1];
|
||||
expectThrows(IllegalArgumentException.class, () -> VectorUtil.int4BitDotProduct(q, d));
|
||||
}
|
||||
}
|
||||
|
||||
static final VectorizationProvider defaultedProvider =
|
||||
BaseVectorizationTestCase.defaultProvider();
|
||||
static final VectorizationProvider defOrPanamaProvider =
|
||||
BaseVectorizationTestCase.maybePanamaProvider();
|
||||
|
||||
public void testBasicInt4BitDotProduct() {
|
||||
testBasicInt4BitDotProductImpl(VectorUtil::int4BitDotProduct);
|
||||
testBasicInt4BitDotProductImpl(defaultedProvider.getVectorUtilSupport()::int4BitDotProduct);
|
||||
testBasicInt4BitDotProductImpl(defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct);
|
||||
}
|
||||
|
||||
interface Int4BitDotProduct {
|
||||
long apply(byte[] q, byte[] d);
|
||||
}
|
||||
|
||||
void testBasicInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) {
|
||||
assertEquals(15L, Int4BitDotProductFunc.apply(new byte[] {1, 1, 1, 1}, new byte[] {1}));
|
||||
assertEquals(
|
||||
30L, Int4BitDotProductFunc.apply(new byte[] {1, 2, 1, 2, 1, 2, 1, 2}, new byte[] {1, 2}));
|
||||
|
||||
var d = new byte[] {1, 2, 3};
|
||||
var q = new byte[] {1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3};
|
||||
assert scalarInt4BitDotProduct(q, d) == 60L; // 4 + 8 + 16 + 32
|
||||
assertEquals(60L, Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
d = new byte[] {1, 2, 3, 4};
|
||||
q = new byte[] {1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4};
|
||||
assert scalarInt4BitDotProduct(q, d) == 75L; // 5 + 10 + 20 + 40
|
||||
assertEquals(75L, Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
d = new byte[] {1, 2, 3, 4, 5};
|
||||
q = new byte[] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
|
||||
assert scalarInt4BitDotProduct(q, d) == 105L; // 7 + 14 + 28 + 56
|
||||
assertEquals(105L, Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
d = new byte[] {1, 2, 3, 4, 5, 6};
|
||||
q = new byte[] {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6};
|
||||
assert scalarInt4BitDotProduct(q, d) == 135L; // 9 + 18 + 36 + 72
|
||||
assertEquals(135L, Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
d = new byte[] {1, 2, 3, 4, 5, 6, 7};
|
||||
q =
|
||||
new byte[] {
|
||||
1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7
|
||||
};
|
||||
assert scalarInt4BitDotProduct(q, d) == 180L; // 12 + 24 + 48 + 96
|
||||
assertEquals(180L, Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
q =
|
||||
new byte[] {
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6,
|
||||
7, 8
|
||||
};
|
||||
assert scalarInt4BitDotProduct(q, d) == 195L; // 13 + 26 + 52 + 104
|
||||
assertEquals(195L, Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
d = new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9};
|
||||
q =
|
||||
new byte[] {
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3,
|
||||
4, 5, 6, 7, 8, 9
|
||||
};
|
||||
assert scalarInt4BitDotProduct(q, d) == 225L; // 15 + 30 + 60 + 120
|
||||
assertEquals(225L, Int4BitDotProductFunc.apply(q, d));
|
||||
}
|
||||
|
||||
public void testInt4BitDotProduct() {
|
||||
testInt4BitDotProductImpl(VectorUtil::int4BitDotProduct);
|
||||
testInt4BitDotProductImpl(defaultedProvider.getVectorUtilSupport()::int4BitDotProduct);
|
||||
testInt4BitDotProductImpl(defOrPanamaProvider.getVectorUtilSupport()::int4BitDotProduct);
|
||||
}
|
||||
|
||||
void testInt4BitDotProductImpl(Int4BitDotProduct Int4BitDotProductFunc) {
|
||||
int iterations = atLeast(50);
|
||||
for (int i = 0; i < iterations; i++) {
|
||||
int size = random().nextInt(5000);
|
||||
var d = new byte[size];
|
||||
var q = new byte[size * 4];
|
||||
random().nextBytes(d);
|
||||
random().nextBytes(q);
|
||||
assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
Arrays.fill(d, Byte.MAX_VALUE);
|
||||
Arrays.fill(q, Byte.MAX_VALUE);
|
||||
assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
|
||||
|
||||
Arrays.fill(d, Byte.MIN_VALUE);
|
||||
Arrays.fill(q, Byte.MIN_VALUE);
|
||||
assertEquals(scalarInt4BitDotProduct(q, d), Int4BitDotProductFunc.apply(q, d));
|
||||
}
|
||||
}
|
||||
|
||||
static int scalarInt4BitDotProduct(byte[] q, byte[] d) {
|
||||
int res = 0;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res += (popcount(q, i * d.length, d, d.length) << i);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
public static int popcount(byte[] a, int aOffset, byte[] b, int length) {
|
||||
int res = 0;
|
||||
for (int j = 0; j < length; j++) {
|
||||
int value = (a[aOffset + j] & b[j]) & 0xFF;
|
||||
for (int k = 0; k < Byte.SIZE; k++) {
|
||||
if ((value & (1 << k)) != 0) {
|
||||
++res;
|
||||
}
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue