mirror of https://github.com/apache/lucene.git
speed up all binary functions on avx256, speed up binary square on avx512
This commit is contained in:
parent
e93bd524cf
commit
3ec9c26d67
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.internal.vectorization;
|
||||
|
||||
import static jdk.incubator.vector.VectorOperators.ADD;
|
||||
import static jdk.incubator.vector.VectorOperators.B2I;
|
||||
import static jdk.incubator.vector.VectorOperators.B2S;
|
||||
import static jdk.incubator.vector.VectorOperators.S2I;
|
||||
|
||||
|
@ -305,7 +306,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && useIntegerVectors) {
|
||||
// compute vectorized dot product consistent with VPDPBUSD instruction
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
res += dotProductBody512(a, b, i);
|
||||
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
res += dotProductBody256(a, b, i);
|
||||
} else {
|
||||
|
@ -322,14 +326,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
return res;
|
||||
}
|
||||
|
||||
/** vectorized dot product body (256+ bit vectors) */
|
||||
private int dotProductBody256(byte[] a, byte[] b, int limit) {
|
||||
/** vectorized dot product body (512 bit vectors) */
|
||||
private int dotProductBody512(byte[] a, byte[] b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 16-bit multiply
|
||||
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
@ -342,6 +346,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
return acc.reduceLanes(ADD);
|
||||
}
|
||||
|
||||
/**
 * Vectorized dot product body (256 bit vectors).
 *
 * <p>Walks {@code a} and {@code b} from index 0 up to {@code limit} in steps of
 * {@code BYTE_SPECIES.length()}, widens each byte vector directly to 32-bit ints, multiplies
 * lane-wise and adds the products into an int accumulator.
 *
 * @param a first byte array; read at indices {@code [0, limit)}
 * @param b second byte array; read at the same indices as {@code a}
 * @param limit exclusive upper bound of the vectorized loop; the visible caller passes
 *     {@code BYTE_SPECIES.loopBound(a.length)} (presumably a multiple of the species length;
 *     confirm against the caller)
 * @return the sum of all lanes of the int accumulator
 */
|
||||
private int dotProductBody256(byte[] a, byte[] b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||

||||
// 32-bit multiply and add into accumulator: bytes are widened straight to ints (B2I), with no 16-bit intermediate step
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
acc = acc.add(va32.mul(vb32));
|
||||
}
|
||||
// reduce: horizontal add of all accumulator lanes into a single int
|
||||
return acc.reduceLanes(ADD);
|
||||
}
|
||||
|
||||
/** vectorized dot product body (128 bit vectors) */
|
||||
private int dotProductBody128(byte[] a, byte[] b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
|
||||
|
@ -374,7 +394,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && useIntegerVectors) {
|
||||
final float[] ret;
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
ret = cosineBody512(a, b, i);
|
||||
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
ret = cosineBody256(a, b, i);
|
||||
} else {
|
||||
|
@ -398,9 +421,8 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
/** vectorized cosine body (256+ bit vectors) */
|
||||
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
|
||||
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
|
||||
/** vectorized cosine body (512 bit vectors) */
|
||||
private float[] cosineBody512(byte[] a, byte[] b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
|
@ -408,20 +430,45 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 16-bit multiply
|
||||
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// sum into accumulators: 32-bit add
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
accSum = accSum.add(prod32);
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
accNorm1 = accNorm1.add(norm1_32);
|
||||
accNorm2 = accNorm2.add(norm2_32);
|
||||
accSum = accSum.add(prod32);
|
||||
}
|
||||
// reduce
|
||||
return new float[] {
|
||||
accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD)
|
||||
};
|
||||
}
|
||||
|
||||
/** vectorized cosine body (256 bit vectors) */
|
||||
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 32-bit multiply, and add into accumulators
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm1_32 = va32.mul(va32);
|
||||
Vector<Integer> norm2_32 = vb32.mul(vb32);
|
||||
Vector<Integer> prod32 = va32.mul(vb32);
|
||||
accNorm1 = accNorm1.add(norm1_32);
|
||||
accNorm2 = accNorm2.add(norm2_32);
|
||||
accSum = accSum.add(prod32);
|
||||
}
|
||||
// reduce
|
||||
return new float[] {
|
||||
|
@ -488,13 +535,11 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 16-bit sub
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> diff16 = va16.sub(vb16);
|
||||
|
||||
// 32-bit multiply and add into accumulators
|
||||
Vector<Integer> diff32 = diff16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
// 32-bit sub, multiply, and add into accumulators
|
||||
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> diff32 = va32.sub(vb32);
|
||||
acc = acc.add(diff32.mul(diff32));
|
||||
}
|
||||
// reduce
|
||||
|
|
Loading…
Reference in New Issue