speedup all binary functions on avx256, speedup binary square on avx512

This commit is contained in:
Robert Muir 2023-10-14 16:28:26 -04:00
parent e93bd524cf
commit 3ec9c26d67
No known key found for this signature in database
GPG Key ID: 817AE1DD322D7ECA
1 changed files with 64 additions and 19 deletions

View File

@ -17,6 +17,7 @@
package org.apache.lucene.internal.vectorization;
import static jdk.incubator.vector.VectorOperators.ADD;
import static jdk.incubator.vector.VectorOperators.B2I;
import static jdk.incubator.vector.VectorOperators.B2S;
import static jdk.incubator.vector.VectorOperators.S2I;
@ -305,7 +306,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && useIntegerVectors) {
// compute vectorized dot product consistent with VPDPBUSD instruction
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.length);
res += dotProductBody512(a, b, i);
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
i += BYTE_SPECIES.loopBound(a.length);
res += dotProductBody256(a, b, i);
} else {
@ -322,14 +326,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
return res;
}
/** vectorized dot product body (256+ bit vectors) */
private int dotProductBody256(byte[] a, byte[] b, int limit) {
/** vectorized dot product body (512 bit vectors) */
private int dotProductBody512(byte[] a, byte[] b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 16-bit multiply
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> prod16 = va16.mul(vb16);
@ -342,6 +346,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
return acc.reduceLanes(ADD);
}
/**
 * Vectorized dot product body (256 bit vectors).
 *
 * <p>Each iteration loads one {@code BYTE_SPECIES}-sized chunk from both arrays, widens the bytes
 * to 32-bit lanes, multiplies lane-wise and adds the products to a running 32-bit accumulator.
 *
 * @param a first byte array
 * @param b second byte array
 * @param limit exclusive end index of the region processed here; presumably a multiple of
 *     {@code BYTE_SPECIES.length()} as produced by {@code loopBound} at the call site (confirm
 *     against the caller)
 * @return the sum of all accumulator lanes
 */
private int dotProductBody256(byte[] a, byte[] b, int limit) {
  IntVector sum = IntVector.zero(IntVector.SPECIES_PREFERRED);
  int offset = 0;
  while (offset < limit) {
    // load one chunk of each input and widen the bytes straight to 32-bit lanes
    Vector<Integer> left =
        ByteVector.fromArray(BYTE_SPECIES, a, offset)
            .convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
    Vector<Integer> right =
        ByteVector.fromArray(BYTE_SPECIES, b, offset)
            .convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
    // 32-bit lane-wise multiply, then fold the products into the accumulator
    Vector<Integer> product = left.mul(right);
    sum = sum.add(product);
    offset += BYTE_SPECIES.length();
  }
  // horizontal add of the accumulator lanes
  return sum.reduceLanes(ADD);
}
/** vectorized dot product body (128 bit vectors) */
private int dotProductBody128(byte[] a, byte[] b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
@ -374,7 +394,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && useIntegerVectors) {
final float[] ret;
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.length);
ret = cosineBody512(a, b, i);
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
i += BYTE_SPECIES.loopBound(a.length);
ret = cosineBody256(a, b, i);
} else {
@ -398,9 +421,8 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
/** vectorized cosine body (256+ bit vectors) */
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
/** vectorized cosine body (512 bit vectors) */
private float[] cosineBody512(byte[] a, byte[] b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
@ -408,20 +430,45 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 16-bit multiply
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> prod16 = va16.mul(vb16);
Vector<Short> norm1_16 = va16.mul(va16);
Vector<Short> norm2_16 = vb16.mul(vb16);
Vector<Short> prod16 = va16.mul(vb16);
// sum into accumulators: 32-bit add
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
accSum = accSum.add(prod32);
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
accNorm1 = accNorm1.add(norm1_32);
accNorm2 = accNorm2.add(norm2_32);
accSum = accSum.add(prod32);
}
// reduce
return new float[] {
accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD)
};
}
/** vectorized cosine body (256 bit vectors) */
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 32-bit multiply, and add into accumulators
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm1_32 = va32.mul(va32);
Vector<Integer> norm2_32 = vb32.mul(vb32);
Vector<Integer> prod32 = va32.mul(vb32);
accNorm1 = accNorm1.add(norm1_32);
accNorm2 = accNorm2.add(norm2_32);
accSum = accSum.add(prod32);
}
// reduce
return new float[] {
@ -488,13 +535,11 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 16-bit sub
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> diff16 = va16.sub(vb16);
// 32-bit multiply and add into accumulators
Vector<Integer> diff32 = diff16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
// 32-bit sub, multiply, and add into accumulators
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> diff32 = va32.sub(vb32);
acc = acc.add(diff32.mul(diff32));
}
// reduce