Speedup integer functions for 128-bit neon vectors (#12632)

This commit is contained in:
Robert Muir 2023-10-14 11:38:33 -04:00 committed by GitHub
parent a4ff129de8
commit 872aee6d18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 34 additions and 51 deletions

View File

@@ -279,27 +279,24 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// reduce
res += acc.reduceLanes(VectorOperators.ADD);
} else {
// 128-bit implementation, which must "split up" vectors due to widening conversions
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
// 128-bit impl, which is tricky since we don't have SPECIES_32, it does "overlapping read"
int upperBound = ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
// 4 bytes at a time
for (; i < upperBound; i += ByteVector.SPECIES_64.length() >> 1) {
// load 8 bytes
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
// expand each byte vector into short vector and multiply
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
// process first "half" only
Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
Vector<Short> prod16 = va16.mul(vb16);
// split each short vector into two int vectors and add
Vector<Integer> prod32_1 =
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
Vector<Integer> prod32_2 =
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
acc1 = acc1.add(prod32_1);
acc2 = acc2.add(prod32_2);
acc = acc.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
}
// reduce
res += acc1.add(acc2).reduceLanes(VectorOperators.ADD);
res += acc.reduceLanes(VectorOperators.ADD);
}
}
@@ -347,47 +344,33 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
norm1 += accNorm1.reduceLanes(VectorOperators.ADD);
norm2 += accNorm2.reduceLanes(VectorOperators.ADD);
} else {
// 128-bit implementation, which must "split up" vectors due to widening conversions
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
IntVector accSum1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accSum2 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1_1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1_2 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2_1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2_2 = IntVector.zero(IntVector.SPECIES_128);
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
// 128-bit impl, which is tricky since we don't have SPECIES_32, it does "overlapping read"
int upperBound = ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
for (; i < upperBound; i += ByteVector.SPECIES_64.length() >> 1) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
// expand each byte vector into short vector and perform multiplications
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
Vector<Short> prod16 = va16.mul(vb16);
// process first half only
Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
Vector<Short> norm1_16 = va16.mul(va16);
Vector<Short> norm2_16 = vb16.mul(vb16);
// split each short vector into two int vectors and add
Vector<Integer> prod32_1 =
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
Vector<Integer> prod32_2 =
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
Vector<Integer> norm1_32_1 =
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
Vector<Integer> norm1_32_2 =
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
Vector<Integer> norm2_32_1 =
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
Vector<Integer> norm2_32_2 =
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
accSum1 = accSum1.add(prod32_1);
accSum2 = accSum2.add(prod32_2);
accNorm1_1 = accNorm1_1.add(norm1_32_1);
accNorm1_2 = accNorm1_2.add(norm1_32_2);
accNorm2_1 = accNorm2_1.add(norm2_32_1);
accNorm2_2 = accNorm2_2.add(norm2_32_2);
Vector<Short> prod16 = va16.mul(vb16);
// sum into accumulators
accNorm1 =
accNorm1.add(norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
accNorm2 =
accNorm2.add(norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
accSum = accSum.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
}
// reduce
sum += accSum1.add(accSum2).reduceLanes(VectorOperators.ADD);
norm1 += accNorm1_1.add(accNorm1_2).reduceLanes(VectorOperators.ADD);
norm2 += accNorm2_1.add(accNorm2_2).reduceLanes(VectorOperators.ADD);
sum += accSum.reduceLanes(VectorOperators.ADD);
norm1 += accNorm1.reduceLanes(VectorOperators.ADD);
norm2 += accNorm2.reduceLanes(VectorOperators.ADD);
}
}