mirror of https://github.com/apache/lucene.git
Speedup integer functions for 128-bit neon vectors (#12632)
This commit is contained in:
parent
a4ff129de8
commit
872aee6d18
|
@ -279,27 +279,24 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// reduce
|
||||
res += acc.reduceLanes(VectorOperators.ADD);
|
||||
} else {
|
||||
// 128-bit implementation, which must "split up" vectors due to widening conversions
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
|
||||
IntVector acc1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector acc2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
|
||||
// 128-bit impl, which is tricky since we don't have SPECIES_32, it does "overlapping read"
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
|
||||
// 4 bytes at a time
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length() >> 1) {
|
||||
// load 8 bytes
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
// expand each byte vector into short vector and multiply
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
|
||||
// process first "half" only
|
||||
Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
|
||||
Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
// split each short vector into two int vectors and add
|
||||
Vector<Integer> prod32_1 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> prod32_2 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
acc1 = acc1.add(prod32_1);
|
||||
acc2 = acc2.add(prod32_2);
|
||||
|
||||
acc = acc.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
|
||||
}
|
||||
// reduce
|
||||
res += acc1.add(acc2).reduceLanes(VectorOperators.ADD);
|
||||
res += acc.reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -347,47 +344,33 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
norm1 += accNorm1.reduceLanes(VectorOperators.ADD);
|
||||
norm2 += accNorm2.reduceLanes(VectorOperators.ADD);
|
||||
} else {
|
||||
// 128-bit implementation, which must "split up" vectors due to widening conversions
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length);
|
||||
IntVector accSum1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accSum2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1_1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1_2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2_1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2_2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length()) {
|
||||
// 128-bit impl, which is tricky since we don't have SPECIES_32, it does "overlapping read"
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length() >> 1) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
// expand each byte vector into short vector and perform multiplications
|
||||
Vector<Short> va16 = va8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(VectorOperators.B2S, ShortVector.SPECIES_128, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// process first half only
|
||||
Vector<Short> va16 = va8.convert(VectorOperators.B2S, 0);
|
||||
Vector<Short> vb16 = vb8.convert(VectorOperators.B2S, 0);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
// split each short vector into two int vectors and add
|
||||
Vector<Integer> prod32_1 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> prod32_2 =
|
||||
prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
Vector<Integer> norm1_32_1 =
|
||||
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> norm1_32_2 =
|
||||
norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
Vector<Integer> norm2_32_1 =
|
||||
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0);
|
||||
Vector<Integer> norm2_32_2 =
|
||||
norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 1);
|
||||
accSum1 = accSum1.add(prod32_1);
|
||||
accSum2 = accSum2.add(prod32_2);
|
||||
accNorm1_1 = accNorm1_1.add(norm1_32_1);
|
||||
accNorm1_2 = accNorm1_2.add(norm1_32_2);
|
||||
accNorm2_1 = accNorm2_1.add(norm2_32_1);
|
||||
accNorm2_2 = accNorm2_2.add(norm2_32_2);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// sum into accumulators
|
||||
accNorm1 =
|
||||
accNorm1.add(norm1_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
|
||||
accNorm2 =
|
||||
accNorm2.add(norm2_16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
|
||||
accSum = accSum.add(prod16.convertShape(VectorOperators.S2I, IntVector.SPECIES_128, 0));
|
||||
}
|
||||
// reduce
|
||||
sum += accSum1.add(accSum2).reduceLanes(VectorOperators.ADD);
|
||||
norm1 += accNorm1_1.add(accNorm1_2).reduceLanes(VectorOperators.ADD);
|
||||
norm2 += accNorm2_1.add(accNorm2_2).reduceLanes(VectorOperators.ADD);
|
||||
sum += accSum.reduceLanes(VectorOperators.ADD);
|
||||
norm1 += accNorm1.reduceLanes(VectorOperators.ADD);
|
||||
norm2 += accNorm2.reduceLanes(VectorOperators.ADD);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue