Reduce unrolling in Panama dotProduct float variant

This commit is contained in:
ChrisHegarty 2024-12-16 16:32:21 +00:00
parent 084480dffb
commit 97b6c7b1bd
1 changed files with 11 additions and 23 deletions

View File

@ -113,15 +113,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/** vectorized float dot product body */
private float dotProductBody(float[] a, float[] b, int limit) {
int i = 0;
// vector loop is unrolled 4x (4 accumulators in parallel)
// we don't know how many the cpu can do at once, some can do 2, some 4
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
// vector loop is unrolled 2x (2 accumulators in parallel)
FloatVector acc1 =
FloatVector.fromArray(FLOAT_SPECIES, a, 0).mul(FloatVector.fromArray(FLOAT_SPECIES, b, 0));
FloatVector acc2 =
FloatVector.fromArray(FLOAT_SPECIES, a, FLOAT_SPECIES.length())
.mul(FloatVector.fromArray(FLOAT_SPECIES, b, FLOAT_SPECIES.length()));
final int unrolledLimit = limit - FLOAT_SPECIES.length();
int i = 2 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
@ -131,27 +131,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
acc2 = fma(vc, vd, acc2);
// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
acc3 = fma(ve, vf, acc3);
// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
acc4 = fma(vg, vh, acc4);
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
if (i < limit) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
acc1 = fma(va, vb, acc1);
}
// reduce
FloatVector res1 = acc1.add(acc2);
FloatVector res2 = acc3.add(acc4);
return res1.add(res2).reduceLanes(ADD);
return res1.reduceLanes(ADD);
}
@Override