Speed up float cosine vectors; use FMA where fast and available to reduce error

This commit is contained in:
Robert Muir 2023-10-27 23:45:14 -04:00
parent 09da2291c5
commit a2016d1d50
No known key found for this signature in database
GPG Key ID: 817AE1DD322D7ECA
1 changed files with 67 additions and 53 deletions

View File

@ -77,6 +77,47 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
VectorizationProvider.TESTS_FORCE_INTEGER_VECTORS || (isAMD64withoutAVX2 == false);
}
private static final String MANAGEMENT_FACTORY_CLASS = "java.lang.management.ManagementFactory";
private static final String HOTSPOT_BEAN_CLASS = "com.sun.management.HotSpotDiagnosticMXBean";

/**
 * Best-effort probe for whether fused multiply-add should be used.
 *
 * <p>Reports {@code true} only when {@code Constants.OS_ARCH} is {@code "amd64"} and the HotSpot
 * diagnostic bean reports the {@code UseFMA} VM option as {@code true}. Any other architecture, a
 * missing bean, or any reflective/runtime failure while querying it yields {@code false}.
 *
 * @return whether FMA is considered fast on this JVM
 */
private static boolean hasFastFMA() {
  // on ARM cpus FMA works fine but is a slight slowdown, so anything that is not amd64 opts out
  if (Constants.OS_ARCH.equals("amd64") == false) {
    return false;
  }
  try {
    final Class<?> diagnosticBeanType = Class.forName(HOTSPOT_BEAN_CLASS);
    // reflection is used here because the management factory is not part
    // of Java 8's compact profile
    final Class<?> managementFactory = Class.forName(MANAGEMENT_FACTORY_CLASS);
    final Object diagnosticBean =
        managementFactory
            .getMethod("getPlatformMXBean", Class.class)
            .invoke(null, diagnosticBeanType);
    if (diagnosticBean == null) {
      return false;
    }
    final Object useFmaOption =
        diagnosticBeanType.getMethod("getVMOption", String.class).invoke(diagnosticBean, "UseFMA");
    final Object optionValue =
        useFmaOption.getClass().getMethod("getValue").invoke(useFmaOption);
    return Boolean.parseBoolean(optionValue.toString());
  } catch (@SuppressWarnings("unused") ReflectiveOperationException | RuntimeException e) {
    // the probe is best effort: if the option cannot be read, do not use FMA
    return false;
  }
}
// set once at class initialization from hasFastFMA(); when true, fma() below uses the
// fused operation, to deliver less error
private static final boolean HAS_FAST_FMA = hasFastFMA();

/**
 * Computes {@code a * b + c} lane-wise: the way FMA should work.
 *
 * <p>Uses {@link FloatVector#fma} when {@code HAS_FAST_FMA} is set, otherwise falls back to a
 * separate multiply followed by an add.
 *
 * @param a first multiplicand
 * @param b second multiplicand
 * @param c addend
 * @return the vector holding {@code a * b + c}
 */
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
  return HAS_FAST_FMA ? a.fma(b, c) : a.mul(b).add(c);
}
@Override
public float dotProduct(float[] a, float[] b) {
int i = 0;
@ -109,28 +150,28 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
acc1 = acc1.add(va.mul(vb));
acc1 = fma(va, vb, acc1);
// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
acc2 = acc2.add(vc.mul(vd));
acc2 = fma(vc, vd, acc2);
// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
acc3 = acc3.add(ve.mul(vf));
acc3 = fma(ve, vf, acc3);
// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
acc4 = acc4.add(vg.mul(vh));
acc4 = fma(vg, vh, acc4);
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
acc1 = acc1.add(va.mul(vb));
acc1 = fma(va, vb, acc1);
}
// reduce
FloatVector res1 = acc1.add(acc2);
@ -168,69 +209,42 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/**
 * Vectorized cosine body: walks {@code a} and {@code b} in steps of {@code FLOAT_SPECIES.length()}
 * while {@code i < limit}, accumulating the lane-wise products a*b, a*a and b*b.
 *
 * <p>NOTE(review): this listing is rendered from a diff, so the removed 4x-unrolled
 * {@code add(mul(...))} statements and their 2x-unrolled {@code fma(...)} replacements both appear
 * below; confirm against the committed file.
 *
 * @param a first input array
 * @param b second input array, read at the same indices as {@code a}
 * @param limit exclusive upper bound for the vector loops (presumably the caller handles any
 *     remaining scalar tail; verify against caller)
 * @return a three-element array holding the lane-reduced sums of a*b, a*a and b*b, in that order
 */
private float[] cosineBody(float[] a, float[] b, int limit) {
int i = 0;
// vector loop is unrolled 4x (4 accumulators in parallel)
// we don't know how many the cpu can do at once, some can do 2, some 4
// vector loop is unrolled 2x (2 accumulators in parallel)
// each iteration has 3 FMAs, so it's a lot already, no need to unroll more
FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
int unrolledLimit = limit - FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = sum1.add(va.mul(vb));
norm1_1 = norm1_1.add(va.mul(va));
norm2_1 = norm2_1.add(vb.mul(vb));
sum1 = fma(va, vb, sum1);
norm1_1 = fma(va, va, norm1_1);
norm2_1 = fma(vb, vb, norm2_1);
// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
sum2 = sum2.add(vc.mul(vd));
norm1_2 = norm1_2.add(vc.mul(vc));
norm2_2 = norm2_2.add(vd.mul(vd));
// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
sum3 = sum3.add(ve.mul(vf));
norm1_3 = norm1_3.add(ve.mul(ve));
norm2_3 = norm2_3.add(vf.mul(vf));
// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
sum4 = sum4.add(vg.mul(vh));
norm1_4 = norm1_4.add(vg.mul(vg));
norm2_4 = norm2_4.add(vh.mul(vh));
sum2 = fma(vc, vd, sum2);
norm1_2 = fma(vc, vc, norm1_2);
norm2_2 = fma(vd, vd, norm2_2);
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = sum1.add(va.mul(vb));
norm1_1 = norm1_1.add(va.mul(va));
norm2_1 = norm2_1.add(vb.mul(vb));
sum1 = fma(va, vb, sum1);
norm1_1 = fma(va, va, norm1_1);
norm2_1 = fma(vb, vb, norm2_1);
}
// reduce
FloatVector sumres1 = sum1.add(sum2);
FloatVector sumres2 = sum3.add(sum4);
FloatVector norm1res1 = norm1_1.add(norm1_2);
FloatVector norm1res2 = norm1_3.add(norm1_4);
FloatVector norm2res1 = norm2_1.add(norm2_2);
FloatVector norm2res2 = norm2_3.add(norm2_4);
return new float[] {
sumres1.add(sumres2).reduceLanes(ADD),
norm1res1.add(norm1res2).reduceLanes(ADD),
norm2res1.add(norm2res2).reduceLanes(ADD)
sum1.add(sum2).reduceLanes(ADD),
norm1_1.add(norm1_2).reduceLanes(ADD),
norm2_1.add(norm2_2).reduceLanes(ADD)
};
}
@ -268,32 +282,32 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
FloatVector diff1 = va.sub(vb);
acc1 = acc1.add(diff1.mul(diff1));
acc1 = fma(diff1, diff1, acc1);
// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
FloatVector diff2 = vc.sub(vd);
acc2 = acc2.add(diff2.mul(diff2));
acc2 = fma(diff2, diff2, acc2);
// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
FloatVector diff3 = ve.sub(vf);
acc3 = acc3.add(diff3.mul(diff3));
acc3 = fma(diff3, diff3, acc3);
// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
FloatVector diff4 = vg.sub(vh);
acc4 = acc4.add(diff4.mul(diff4));
acc4 = fma(diff4, diff4, acc4);
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
FloatVector diff = va.sub(vb);
acc1 = acc1.add(diff.mul(diff));
acc1 = fma(diff, diff, acc1);
}
// reduce
FloatVector res1 = acc1.add(acc2);