From 97b6c7b1bd1be3a3e7a2a3164181c6b9e392ca3d Mon Sep 17 00:00:00 2001 From: ChrisHegarty Date: Mon, 16 Dec 2024 16:32:21 +0000 Subject: [PATCH] Reduce unrolling in Panama dotProduct float variant --- .../PanamaVectorUtilSupport.java | 34 ++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java index 9273f7c5a81..a2f080d778b 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/PanamaVectorUtilSupport.java @@ -113,15 +113,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { /** vectorized float dot product body */ private float dotProductBody(float[] a, float[] b, int limit) { - int i = 0; - // vector loop is unrolled 4x (4 accumulators in parallel) - // we don't know how many the cpu can do at once, some can do 2, some 4 - FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES); - FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES); - FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES); - FloatVector acc4 = FloatVector.zero(FLOAT_SPECIES); - int unrolledLimit = limit - 3 * FLOAT_SPECIES.length(); - for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) { + // vector loop is unrolled 2x (2 accumulators in parallel) + FloatVector acc1 = + FloatVector.fromArray(FLOAT_SPECIES, a, 0).mul(FloatVector.fromArray(FLOAT_SPECIES, b, 0)); + FloatVector acc2 = + FloatVector.fromArray(FLOAT_SPECIES, a, FLOAT_SPECIES.length()) + .mul(FloatVector.fromArray(FLOAT_SPECIES, b, FLOAT_SPECIES.length())); + final int unrolledLimit = limit - FLOAT_SPECIES.length(); + int i = 2 * FLOAT_SPECIES.length(); + for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) { // one FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); @@ -131,27 +131,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport { FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length()); FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length()); acc2 = fma(vc, vd, acc2); - - // three - FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length()); - FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length()); - acc3 = fma(ve, vf, acc3); - - // four - FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length()); - FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length()); - acc4 = fma(vg, vh, acc4); } - // vector tail: less scalar computations for unaligned sizes, esp with big vector sizes - for (; i < limit; i += FLOAT_SPECIES.length()) { + if (i < limit) { FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i); FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i); acc1 = fma(va, vb, acc1); } // reduce FloatVector res1 = acc1.add(acc2); - FloatVector res2 = acc3.add(acc4); - return res1.add(res2).reduceLanes(ADD); + return res1.reduceLanes(ADD); } @Override