diff --git a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index e509f6ef49c..b54136b69fa 100644 --- a/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java20/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -77,6 +77,47 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { VectorizationProvider.TESTS_FORCE_INTEGER_VECTORS || (isAMD64withoutAVX2 == false); } + private static final String MANAGEMENT_FACTORY_CLASS = "java.lang.management.ManagementFactory"; + private static final String HOTSPOT_BEAN_CLASS = "com.sun.management.HotSpotDiagnosticMXBean"; + + // best effort to see if FMA is fast (this is architecture-independent option) + private static boolean hasFastFMA() { + // on ARM cpus, FMA works fine but is a slight slowdown: don't use it. + if (Constants.OS_ARCH.equals("amd64") == false) { + return false; + } + try { + final Class beanClazz = Class.forName(HOTSPOT_BEAN_CLASS); + // we use reflection for this, because the management factory is not part + // of Java 8's compact profile: + final Object hotSpotBean = + Class.forName(MANAGEMENT_FACTORY_CLASS) + .getMethod("getPlatformMXBean", Class.class) + .invoke(null, beanClazz); + if (hotSpotBean != null) { + final var getVMOptionMethod = beanClazz.getMethod("getVMOption", String.class); + final Object vmOption = getVMOptionMethod.invoke(hotSpotBean, "UseFMA"); + return Boolean.parseBoolean( + vmOption.getClass().getMethod("getValue").invoke(vmOption).toString()); + } + return false; + } catch (@SuppressWarnings("unused") ReflectiveOperationException | RuntimeException e) { + return false; + } + } + + // true if we know FMA is supported, to deliver less error + private static final boolean HAS_FAST_FMA = hasFastFMA(); + + // the way FMA should work! if available use it, otherwise fall back to mul/add + private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) { + if (HAS_FAST_FMA) { + return a.fma(b, c); + } else { + return a.mul(b).add(c); + } + } + @Override public float dotProduct(float[] a, float[] b) { int i = 0; @@ -109,28 +150,28 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { // one FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); - acc1 = acc1.add(va.mul(vb)); + acc1 = fma(va, vb, acc1); // two FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); - acc2 = acc2.add(vc.mul(vd)); + acc2 = fma(vc, vd, acc2); // three FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length()); FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length()); - acc3 = acc3.add(ve.mul(vf)); + acc3 = fma(ve, vf, acc3); // four FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length()); FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); - acc4 = acc4.add(vg.mul(vh)); + acc4 = fma(vg, vh, acc4); } // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); - acc1 = acc1.add(va.mul(vb)); + acc1 = fma(va, vb, acc1); } // reduce FloatVector res1 = acc1.add(acc2); @@ -168,69 +209,42 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { /** vectorized cosine body */ private float[] cosineBody(float[] a, float[] b, int limit) { int i = 0; - // vector loop is unrolled 4x (4 accumulators in parallel) - // we don't know how many the cpu can do at once, some can do 2, some 4 + // vector loop is unrolled 2x (2 accumulators in parallel) + // each iteration has 3 FMAs, so its a lot already, no need to unroll more FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES); FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES); - FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES); - FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES); FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES); FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES); - FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES); - FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES); FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES); FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES); - FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES); - FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES); - int unrolledLimit = limit - 3 * FLOAT_SPECIES.length(); - for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) { + int unrolledLimit = limit - FLOAT_SPECIES.length(); + for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) { // one FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); - sum1 = sum1.add(va.mul(vb)); - norm1_1 = norm1_1.add(va.mul(va)); - norm2_1 = norm2_1.add(vb.mul(vb)); + sum1 = fma(va, vb, sum1); + norm1_1 = fma(va, va, norm1_1); + norm2_1 = fma(vb, vb, norm2_1); // two FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); - sum2 = sum2.add(vc.mul(vd)); - norm1_2 = norm1_2.add(vc.mul(vc)); - norm2_2 = norm2_2.add(vd.mul(vd)); - - // three - FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length()); - FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length()); - sum3 = sum3.add(ve.mul(vf)); - norm1_3 = norm1_3.add(ve.mul(ve)); - norm2_3 = norm2_3.add(vf.mul(vf)); - - // four - FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length()); - FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); - sum4 = sum4.add(vg.mul(vh)); - norm1_4 = norm1_4.add(vg.mul(vg)); - norm2_4 = norm2_4.add(vh.mul(vh)); + sum2 = fma(vc, vd, sum2); + norm1_2 = fma(vc, vc, norm1_2); + norm2_2 = fma(vd, vd, norm2_2); } // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); - sum1 = sum1.add(va.mul(vb)); - norm1_1 = norm1_1.add(va.mul(va)); - norm2_1 = norm2_1.add(vb.mul(vb)); + sum1 = fma(va, vb, sum1); + norm1_1 = fma(va, va, norm1_1); + norm2_1 = fma(vb, vb, norm2_1); } - // reduce - FloatVector sumres1 = sum1.add(sum2); - FloatVector sumres2 = sum3.add(sum4); - FloatVector norm1res1 = norm1_1.add(norm1_2); - FloatVector norm1res2 = norm1_3.add(norm1_4); - FloatVector norm2res1 = norm2_1.add(norm2_2); - FloatVector norm2res2 = norm2_3.add(norm2_4); return new float[] { - sumres1.add(sumres2).reduceLanes(ADD), - norm1res1.add(norm1res2).reduceLanes(ADD), - norm2res1.add(norm2res2).reduceLanes(ADD) + sum1.add(sum2).reduceLanes(ADD), + norm1_1.add(norm1_2).reduceLanes(ADD), + norm2_1.add(norm2_2).reduceLanes(ADD) }; } @@ -268,32 +282,32 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); FloatVector diff1 = va.sub(vb); - acc1 = acc1.add(diff1.mul(diff1)); + acc1 = fma(diff1, diff1, acc1); // two FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); FloatVector diff2 = vc.sub(vd); - acc2 = acc2.add(diff2.mul(diff2)); + acc2 = fma(diff2, diff2, acc2); // three FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length()); FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length()); FloatVector diff3 = ve.sub(vf); - acc3 = acc3.add(diff3.mul(diff3)); + acc3 = fma(diff3, diff3, acc3); // four FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length()); FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); FloatVector diff4 = vg.sub(vh); - acc4 = acc4.add(diff4.mul(diff4)); + acc4 = fma(diff4, diff4, acc4); } // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes for (; i < limit; i += FLOAT_SPECIES.length()) { FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); FloatVector diff = va.sub(vb); - acc1 = acc1.add(diff.mul(diff)); + acc1 = fma(diff, diff, acc1); } // reduce FloatVector res1 = acc1.add(acc2);