mirror of https://github.com/apache/lucene.git
Speedup float cosine vectors, use FMA where fast and available to reduce error
This commit is contained in:
parent
09da2291c5
commit
a2016d1d50
|
@ -77,6 +77,47 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
VectorizationProvider.TESTS_FORCE_INTEGER_VECTORS || (isAMD64withoutAVX2 == false);
|
||||
}
|
||||
|
||||
private static final String MANAGEMENT_FACTORY_CLASS = "java.lang.management.ManagementFactory";
|
||||
private static final String HOTSPOT_BEAN_CLASS = "com.sun.management.HotSpotDiagnosticMXBean";
|
||||
|
||||
// best effort to see if FMA is fast (this is architecture-independent option)
|
||||
private static boolean hasFastFMA() {
|
||||
// on ARM cpus, FMA works fine but is a slight slowdown: don't use it.
|
||||
if (Constants.OS_ARCH.equals("amd64") == false) {
|
||||
return false;
|
||||
}
|
||||
try {
|
||||
final Class<?> beanClazz = Class.forName(HOTSPOT_BEAN_CLASS);
|
||||
// we use reflection for this, because the management factory is not part
|
||||
// of Java 8's compact profile:
|
||||
final Object hotSpotBean =
|
||||
Class.forName(MANAGEMENT_FACTORY_CLASS)
|
||||
.getMethod("getPlatformMXBean", Class.class)
|
||||
.invoke(null, beanClazz);
|
||||
if (hotSpotBean != null) {
|
||||
final var getVMOptionMethod = beanClazz.getMethod("getVMOption", String.class);
|
||||
final Object vmOption = getVMOptionMethod.invoke(hotSpotBean, "UseFMA");
|
||||
return Boolean.parseBoolean(
|
||||
vmOption.getClass().getMethod("getValue").invoke(vmOption).toString());
|
||||
}
|
||||
return false;
|
||||
} catch (@SuppressWarnings("unused") ReflectiveOperationException | RuntimeException e) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// true if we know FMA is supported, to deliver less error
|
||||
private static final boolean HAS_FAST_FMA = hasFastFMA();
|
||||
|
||||
// the way FMA should work! if available use it, otherwise fall back to mul/add
|
||||
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
|
||||
if (HAS_FAST_FMA) {
|
||||
return a.fma(b, c);
|
||||
} else {
|
||||
return a.mul(b).add(c);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public float dotProduct(float[] a, float[] b) {
|
||||
int i = 0;
|
||||
|
@ -109,28 +150,28 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// one
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
acc1 = acc1.add(va.mul(vb));
|
||||
acc1 = fma(va, vb, acc1);
|
||||
|
||||
// two
|
||||
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
|
||||
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
|
||||
acc2 = acc2.add(vc.mul(vd));
|
||||
acc2 = fma(vc, vd, acc2);
|
||||
|
||||
// three
|
||||
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
|
||||
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
|
||||
acc3 = acc3.add(ve.mul(vf));
|
||||
acc3 = fma(ve, vf, acc3);
|
||||
|
||||
// four
|
||||
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
|
||||
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
|
||||
acc4 = acc4.add(vg.mul(vh));
|
||||
acc4 = fma(vg, vh, acc4);
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
for (; i < limit; i += FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
acc1 = acc1.add(va.mul(vb));
|
||||
acc1 = fma(va, vb, acc1);
|
||||
}
|
||||
// reduce
|
||||
FloatVector res1 = acc1.add(acc2);
|
||||
|
@ -168,69 +209,42 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
/** vectorized cosine body */
|
||||
private float[] cosineBody(float[] a, float[] b, int limit) {
|
||||
int i = 0;
|
||||
// vector loop is unrolled 4x (4 accumulators in parallel)
|
||||
// we don't know how many the cpu can do at once, some can do 2, some 4
|
||||
// vector loop is unrolled 2x (2 accumulators in parallel)
|
||||
// each iteration has 3 FMAs, so its a lot already, no need to unroll more
|
||||
FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
|
||||
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
|
||||
int unrolledLimit = limit - FLOAT_SPECIES.length();
|
||||
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
|
||||
// one
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
sum1 = fma(va, vb, sum1);
|
||||
norm1_1 = fma(va, va, norm1_1);
|
||||
norm2_1 = fma(vb, vb, norm2_1);
|
||||
|
||||
// two
|
||||
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
|
||||
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
|
||||
sum2 = sum2.add(vc.mul(vd));
|
||||
norm1_2 = norm1_2.add(vc.mul(vc));
|
||||
norm2_2 = norm2_2.add(vd.mul(vd));
|
||||
|
||||
// three
|
||||
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
|
||||
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
|
||||
sum3 = sum3.add(ve.mul(vf));
|
||||
norm1_3 = norm1_3.add(ve.mul(ve));
|
||||
norm2_3 = norm2_3.add(vf.mul(vf));
|
||||
|
||||
// four
|
||||
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
|
||||
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
|
||||
sum4 = sum4.add(vg.mul(vh));
|
||||
norm1_4 = norm1_4.add(vg.mul(vg));
|
||||
norm2_4 = norm2_4.add(vh.mul(vh));
|
||||
sum2 = fma(vc, vd, sum2);
|
||||
norm1_2 = fma(vc, vc, norm1_2);
|
||||
norm2_2 = fma(vd, vd, norm2_2);
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
for (; i < limit; i += FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
sum1 = fma(va, vb, sum1);
|
||||
norm1_1 = fma(va, va, norm1_1);
|
||||
norm2_1 = fma(vb, vb, norm2_1);
|
||||
}
|
||||
// reduce
|
||||
FloatVector sumres1 = sum1.add(sum2);
|
||||
FloatVector sumres2 = sum3.add(sum4);
|
||||
FloatVector norm1res1 = norm1_1.add(norm1_2);
|
||||
FloatVector norm1res2 = norm1_3.add(norm1_4);
|
||||
FloatVector norm2res1 = norm2_1.add(norm2_2);
|
||||
FloatVector norm2res2 = norm2_3.add(norm2_4);
|
||||
return new float[] {
|
||||
sumres1.add(sumres2).reduceLanes(ADD),
|
||||
norm1res1.add(norm1res2).reduceLanes(ADD),
|
||||
norm2res1.add(norm2res2).reduceLanes(ADD)
|
||||
sum1.add(sum2).reduceLanes(ADD),
|
||||
norm1_1.add(norm1_2).reduceLanes(ADD),
|
||||
norm2_1.add(norm2_2).reduceLanes(ADD)
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -268,32 +282,32 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
FloatVector diff1 = va.sub(vb);
|
||||
acc1 = acc1.add(diff1.mul(diff1));
|
||||
acc1 = fma(diff1, diff1, acc1);
|
||||
|
||||
// two
|
||||
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
|
||||
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
|
||||
FloatVector diff2 = vc.sub(vd);
|
||||
acc2 = acc2.add(diff2.mul(diff2));
|
||||
acc2 = fma(diff2, diff2, acc2);
|
||||
|
||||
// three
|
||||
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
|
||||
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
|
||||
FloatVector diff3 = ve.sub(vf);
|
||||
acc3 = acc3.add(diff3.mul(diff3));
|
||||
acc3 = fma(diff3, diff3, acc3);
|
||||
|
||||
// four
|
||||
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
|
||||
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
|
||||
FloatVector diff4 = vg.sub(vh);
|
||||
acc4 = acc4.add(diff4.mul(diff4));
|
||||
acc4 = fma(diff4, diff4, acc4);
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
for (; i < limit; i += FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
FloatVector diff = va.sub(vb);
|
||||
acc1 = acc1.add(diff.mul(diff));
|
||||
acc1 = fma(diff, diff, acc1);
|
||||
}
|
||||
// reduce
|
||||
FloatVector res1 = acc1.add(acc2);
|
||||
|
|
Loading…
Reference in New Issue