Speed up vectorutil float scalar methods, unroll properly, use fma where possible (#12737)

Co-authored-by: Uwe Schindler <uschindler@apache.org>
This commit is contained in:
Robert Muir 2023-11-04 19:25:58 -04:00 committed by GitHub
parent b8a9b0ae29
commit 40e55b0ce7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 210 additions and 111 deletions

View File

@ -17,72 +17,46 @@
package org.apache.lucene.internal.vectorization; package org.apache.lucene.internal.vectorization;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;
final class DefaultVectorUtilSupport implements VectorUtilSupport { final class DefaultVectorUtilSupport implements VectorUtilSupport {
DefaultVectorUtilSupport() {} DefaultVectorUtilSupport() {}
/**
 * Computes {@code a * b + c} for scalar floats.
 *
 * <p>When {@code Constants.HAS_FAST_SCALAR_FMA} is set, the fused {@link Math#fma(float, float,
 * float)} is used; otherwise the product and the sum are evaluated as two separate operations.
 *
 * @param a first factor
 * @param b second factor
 * @param c addend
 * @return the multiply-add result
 */
@SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
private static float fma(float a, float b, float c) {
  // the way FMA should work! if available use it, otherwise fall back to mul/add
  return Constants.HAS_FAST_SCALAR_FMA ? Math.fma(a, b, c) : a * b + c;
}
@Override @Override
public float dotProduct(float[] a, float[] b) { public float dotProduct(float[] a, float[] b) {
float res = 0f; float res = 0f;
/* int i = 0;
* If length of vector is larger than 8, we use unrolled dot product to accelerate the
* calculation. // if the array is big, unroll it
*/ if (a.length > 32) {
int i; float acc1 = 0;
for (i = 0; i < a.length % 8; i++) { float acc2 = 0;
res += b[i] * a[i]; float acc3 = 0;
float acc4 = 0;
int upperBound = a.length & ~(4 - 1);
for (; i < upperBound; i += 4) {
acc1 = fma(a[i], b[i], acc1);
acc2 = fma(a[i + 1], b[i + 1], acc2);
acc3 = fma(a[i + 2], b[i + 2], acc3);
acc4 = fma(a[i + 3], b[i + 3], acc4);
}
res += acc1 + acc2 + acc3 + acc4;
} }
if (a.length < 8) {
return res; for (; i < a.length; i++) {
} res = fma(a[i], b[i], res);
for (; i + 31 < a.length; i += 32) {
res +=
b[i + 0] * a[i + 0]
+ b[i + 1] * a[i + 1]
+ b[i + 2] * a[i + 2]
+ b[i + 3] * a[i + 3]
+ b[i + 4] * a[i + 4]
+ b[i + 5] * a[i + 5]
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
res +=
b[i + 8] * a[i + 8]
+ b[i + 9] * a[i + 9]
+ b[i + 10] * a[i + 10]
+ b[i + 11] * a[i + 11]
+ b[i + 12] * a[i + 12]
+ b[i + 13] * a[i + 13]
+ b[i + 14] * a[i + 14]
+ b[i + 15] * a[i + 15];
res +=
b[i + 16] * a[i + 16]
+ b[i + 17] * a[i + 17]
+ b[i + 18] * a[i + 18]
+ b[i + 19] * a[i + 19]
+ b[i + 20] * a[i + 20]
+ b[i + 21] * a[i + 21]
+ b[i + 22] * a[i + 22]
+ b[i + 23] * a[i + 23];
res +=
b[i + 24] * a[i + 24]
+ b[i + 25] * a[i + 25]
+ b[i + 26] * a[i + 26]
+ b[i + 27] * a[i + 27]
+ b[i + 28] * a[i + 28]
+ b[i + 29] * a[i + 29]
+ b[i + 30] * a[i + 30]
+ b[i + 31] * a[i + 31];
}
for (; i + 7 < a.length; i += 8) {
res +=
b[i + 0] * a[i + 0]
+ b[i + 1] * a[i + 1]
+ b[i + 2] * a[i + 2]
+ b[i + 3] * a[i + 3]
+ b[i + 4] * a[i + 4]
+ b[i + 5] * a[i + 5]
+ b[i + 6] * a[i + 6]
+ b[i + 7] * a[i + 7];
} }
return res; return res;
} }
@ -92,50 +66,80 @@ final class DefaultVectorUtilSupport implements VectorUtilSupport {
float sum = 0.0f; float sum = 0.0f;
float norm1 = 0.0f; float norm1 = 0.0f;
float norm2 = 0.0f; float norm2 = 0.0f;
int dim = a.length; int i = 0;
for (int i = 0; i < dim; i++) { // if the array is big, unroll it
float elem1 = a[i]; if (a.length > 32) {
float elem2 = b[i]; float sum1 = 0;
sum += elem1 * elem2; float sum2 = 0;
norm1 += elem1 * elem1; float norm1_1 = 0;
norm2 += elem2 * elem2; float norm1_2 = 0;
float norm2_1 = 0;
float norm2_2 = 0;
int upperBound = a.length & ~(2 - 1);
for (; i < upperBound; i += 2) {
// one
sum1 = fma(a[i], b[i], sum1);
norm1_1 = fma(a[i], a[i], norm1_1);
norm2_1 = fma(b[i], b[i], norm2_1);
// two
sum2 = fma(a[i + 1], b[i + 1], sum2);
norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
}
sum += sum1 + sum2;
norm1 += norm1_1 + norm1_2;
norm2 += norm2_1 + norm2_2;
}
for (; i < a.length; i++) {
sum = fma(a[i], b[i], sum);
norm1 = fma(a[i], a[i], norm1);
norm2 = fma(b[i], b[i], norm2);
} }
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
} }
@Override @Override
public float squareDistance(float[] a, float[] b) { public float squareDistance(float[] a, float[] b) {
float squareSum = 0.0f; float res = 0;
int dim = a.length; int i = 0;
int i;
for (i = 0; i + 8 <= dim; i += 8) {
squareSum += squareDistanceUnrolled(a, b, i);
}
for (; i < dim; i++) {
float diff = a[i] - b[i];
squareSum += diff * diff;
}
return squareSum;
}
private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) { // if the array is big, unroll it
float diff0 = v1[index + 0] - v2[index + 0]; if (a.length > 32) {
float diff1 = v1[index + 1] - v2[index + 1]; float acc1 = 0;
float diff2 = v1[index + 2] - v2[index + 2]; float acc2 = 0;
float diff3 = v1[index + 3] - v2[index + 3]; float acc3 = 0;
float diff4 = v1[index + 4] - v2[index + 4]; float acc4 = 0;
float diff5 = v1[index + 5] - v2[index + 5];
float diff6 = v1[index + 6] - v2[index + 6]; int upperBound = a.length & ~(4 - 1);
float diff7 = v1[index + 7] - v2[index + 7]; for (; i < upperBound; i += 4) {
return diff0 * diff0 // one
+ diff1 * diff1 float diff1 = a[i] - b[i];
+ diff2 * diff2 acc1 = fma(diff1, diff1, acc1);
+ diff3 * diff3
+ diff4 * diff4 // two
+ diff5 * diff5 float diff2 = a[i + 1] - b[i + 1];
+ diff6 * diff6 acc2 = fma(diff2, diff2, acc2);
+ diff7 * diff7;
// three
float diff3 = a[i + 2] - b[i + 2];
acc3 = fma(diff3, diff3, acc3);
// four
float diff4 = a[i + 3] - b[i + 3];
acc4 = fma(diff4, diff4, acc4);
}
res += acc1 + acc2 + acc3 + acc4;
}
for (; i < a.length; i++) {
float diff = a[i] - b[i];
res = fma(diff, diff, res);
}
return res;
} }
@Override @Override

View File

@ -18,7 +18,6 @@ package org.apache.lucene.util;
import java.security.AccessController; import java.security.AccessController;
import java.security.PrivilegedAction; import java.security.PrivilegedAction;
import java.util.Objects;
import java.util.logging.Logger; import java.util.logging.Logger;
/** Some useful constants. */ /** Some useful constants. */
@ -67,12 +66,6 @@ public final class Constants {
/** True iff running on a 64bit JVM */ /** True iff running on a 64bit JVM */
public static final boolean JRE_IS_64BIT = is64Bit(); public static final boolean JRE_IS_64BIT = is64Bit();
/** true iff we know fast FMA is supported, to deliver less error */
public static final boolean HAS_FAST_FMA =
(IS_CLIENT_VM == false)
&& Objects.equals(OS_ARCH, "amd64")
&& HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
private static boolean is64Bit() { private static boolean is64Bit() {
final String datamodel = getSysProp("sun.arch.data.model"); final String datamodel = getSysProp("sun.arch.data.model");
if (datamodel != null) { if (datamodel != null) {
@ -82,6 +75,76 @@ public final class Constants {
} }
} }
// NOTE: the three private constants below must stay declared BEFORE HAS_FAST_VECTOR_FMA and
// HAS_FAST_SCALAR_FMA. Static initializers run in textual order, and hasFastVectorFMA() /
// hasFastScalarFMA() read HAS_FMA, HAS_SSE4A and MAX_VECTOR_SIZE while computing their result.
/**
 * true if FMA likely means a cpu instruction and not BigDecimal logic. Requires a non-client VM
 * and the HotSpot {@code UseFMA} option to be reported as true; falls back to false when the
 * option is not available.
 */
private static final boolean HAS_FMA =
(IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
/**
 * maximum supported vector size, as reported by the HotSpot {@code MaxVectorSize} option; 0 when
 * the option is not available. NOTE(review): presumably expressed in bytes, confirm against the
 * HotSpot documentation.
 */
private static final int MAX_VECTOR_SIZE =
HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0);
/**
 * true for an AMD cpu with SSE4a instructions. Detected through the HotSpot {@code UseXmmI2F}
 * option, which is used here as a proxy; false when the option is not available.
 */
private static final boolean HAS_SSE4A =
HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false);
/**
 * true iff we know VFMA has faster throughput than separate vmul/vadd. Can be forced with the
 * {@code lucene.useVectorFMA} system property when FMA is available at all.
 */
public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();
/**
 * true iff we know FMA has faster throughput than separate mul/add. Can be forced with the {@code
 * lucene.useScalarFMA} system property when FMA is available at all.
 */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();
/**
 * Decides whether vector FMA should be preferred over separate vector multiply and add.
 *
 * <p>Always returns {@code false} when {@code HAS_FMA} is false. Otherwise the {@code
 * lucene.useVectorFMA} system property is consulted (default {@code auto}): any value other than
 * {@code auto} is interpreted with {@link Boolean#parseBoolean(String)}, while {@code auto}
 * applies the CPU heuristics below.
 *
 * <p>The architecture checks are written constant-first so that a missing {@code os.arch} value
 * yields {@code false} instead of a {@link NullPointerException} during class initialization.
 *
 * @return {@code true} if vector FMA is considered fast on this JVM/CPU
 */
private static boolean hasFastVectorFMA() {
  if (HAS_FMA == false) {
    return false;
  }
  String value = getSysProp("lucene.useVectorFMA", "auto");
  if ("auto".equals(value) == false) {
    // explicit user override
    return Boolean.parseBoolean(value);
  }
  // newer Neoverse cores have their act together
  // the problem is just apple silicon (this is a practical heuristic)
  if ("aarch64".equals(OS_ARCH) && MAC_OS_X == false) {
    return true;
  }
  // zen cores or newer, it's a wash, turn it on as it doesn't hurt
  // starts to yield gains for vectors only at zen4+
  if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) {
    return true;
  }
  // intel has their act together
  if ("amd64".equals(OS_ARCH) && HAS_SSE4A == false) {
    return true;
  }
  // everyone else is slow, until proven otherwise by benchmarks
  return false;
}
/**
 * Decides whether scalar FMA should be preferred over separate multiply and add.
 *
 * <p>Always returns {@code false} when {@code HAS_FMA} is false. Otherwise the {@code
 * lucene.useScalarFMA} system property is consulted (default {@code auto}): any value other than
 * {@code auto} is interpreted with {@link Boolean#parseBoolean(String)}, while {@code auto}
 * applies the CPU heuristics below.
 *
 * <p>The architecture checks are written constant-first so that a missing {@code os.arch} value
 * yields {@code false} instead of a {@link NullPointerException} during class initialization.
 *
 * @return {@code true} if scalar FMA is considered fast on this JVM/CPU
 */
private static boolean hasFastScalarFMA() {
  if (HAS_FMA == false) {
    return false;
  }
  String value = getSysProp("lucene.useScalarFMA", "auto");
  if ("auto".equals(value) == false) {
    // explicit user override
    return Boolean.parseBoolean(value);
  }
  // newer Neoverse cores have their act together
  // the problem is just apple silicon (this is a practical heuristic)
  if ("aarch64".equals(OS_ARCH) && MAC_OS_X == false) {
    return true;
  }
  // latency becomes 4 for the Zen3 (0x19h), but still a wash
  // until the Zen4 anyway, and big drop on previous zens:
  if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) {
    return true;
  }
  // intel has their act together
  if ("amd64".equals(OS_ARCH) && HAS_SSE4A == false) {
    return true;
  }
  // everyone else is slow, until proven otherwise by benchmarks
  return false;
}
private static String getSysProp(String property) { private static String getSysProp(String property) {
try { try {
return doPrivileged(() -> System.getProperty(property)); return doPrivileged(() -> System.getProperty(property));

View File

@ -20,7 +20,31 @@ package org.apache.lucene.util;
import org.apache.lucene.internal.vectorization.VectorUtilSupport; import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider; import org.apache.lucene.internal.vectorization.VectorizationProvider;
/** Utilities for computations with numeric arrays */ /**
* Utilities for computations with numeric arrays, especially algebraic operations like vector dot
* products. This class uses SIMD vectorization if the corresponding Java module is available and
* enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's
* command line.
*
* <p>It will use CPU's <a href="https://en.wikipedia.org/wiki/Fused_multiply%E2%80%93add">FMA
* instructions</a> if it is known to perform faster than separate multiply+add. This requires at
* least Hotspot C2 enabled, which is the default for OpenJDK based JVMs.
*
* <p>To explicitly disable or enable FMA usage, pass the following system properties:
*
* <ul>
* <li>{@code -Dlucene.useScalarFMA=(auto|true|false)} for scalar operations
* <li>{@code -Dlucene.useVectorFMA=(auto|true|false)} for vectorized operations (with vector
* incubator module)
* </ul>
*
* <p>The default is {@code auto}, which enables this for known CPU types and JVM settings. If
* Hotspot C2 is disabled, FMA and vectorization are <strong>not</strong> used.
*
* <p>Vectorization and FMA are only supported for Hotspot-based JVMs; they won't work on
* OpenJ9-based JVMs unless those JVMs provide {@link com.sun.management.HotSpotDiagnosticMXBean}.
* Please also make sure that you have the {@code jdk.management} module enabled in modularized
* applications.
*/
public final class VectorUtil { public final class VectorUtil {
private static final VectorUtilSupport IMPL = private static final VectorUtilSupport IMPL =

View File

@ -29,6 +29,7 @@ import jdk.incubator.vector.Vector;
import jdk.incubator.vector.VectorShape; import jdk.incubator.vector.VectorShape;
import jdk.incubator.vector.VectorSpecies; import jdk.incubator.vector.VectorSpecies;
import org.apache.lucene.util.Constants; import org.apache.lucene.util.Constants;
import org.apache.lucene.util.SuppressForbidden;
/** /**
* VectorUtil methods implemented with Panama incubating vector API. * VectorUtil methods implemented with Panama incubating vector API.
@ -79,13 +80,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// the way FMA should work! if available use it, otherwise fall back to mul/add // the way FMA should work! if available use it, otherwise fall back to mul/add
private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) { private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
if (Constants.HAS_FAST_FMA) { if (Constants.HAS_FAST_VECTOR_FMA) {
return a.fma(b, c); return a.fma(b, c);
} else { } else {
return a.mul(b).add(c); return a.mul(b).add(c);
} }
} }
/**
 * Scalar counterpart of the vector {@code fma} helper, used by the scalar tail loops.
 *
 * <p>Computes {@code a * b + c} through {@link Math#fma(float, float, float)} when {@code
 * Constants.HAS_FAST_SCALAR_FMA} is set, and as a separate multiply and add otherwise.
 *
 * @param a first factor
 * @param b second factor
 * @param c addend
 * @return the multiply-add result
 */
@SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
private static float fma(float a, float b, float c) {
  return Constants.HAS_FAST_SCALAR_FMA ? Math.fma(a, b, c) : a * b + c;
}
@Override @Override
public float dotProduct(float[] a, float[] b) { public float dotProduct(float[] a, float[] b) {
int i = 0; int i = 0;
@ -99,7 +109,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// scalar tail // scalar tail
for (; i < a.length; i++) { for (; i < a.length; i++) {
res += b[i] * a[i]; res = fma(a[i], b[i], res);
} }
return res; return res;
} }
@ -165,11 +175,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// scalar tail // scalar tail
for (; i < a.length; i++) { for (; i < a.length; i++) {
float elem1 = a[i]; sum = fma(a[i], b[i], sum);
float elem2 = b[i]; norm1 = fma(a[i], a[i], norm1);
sum += elem1 * elem2; norm2 = fma(b[i], b[i], norm2);
norm1 += elem1 * elem1;
norm2 += elem2 * elem2;
} }
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2)); return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
} }
@ -230,7 +238,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// scalar tail // scalar tail
for (; i < a.length; i++) { for (; i < a.length; i++) {
float diff = a[i] - b[i]; float diff = a[i] - b[i];
res += diff * diff; res = fma(diff, diff, res);
} }
return res; return res;
} }

View File

@ -63,7 +63,7 @@ final class PanamaVectorizationProvider extends VectorizationProvider {
Locale.ENGLISH, Locale.ENGLISH,
"Java vector incubator API enabled; uses preferredBitSize=%d%s%s", "Java vector incubator API enabled; uses preferredBitSize=%d%s%s",
PanamaVectorUtilSupport.VECTOR_BITSIZE, PanamaVectorUtilSupport.VECTOR_BITSIZE,
Constants.HAS_FAST_FMA ? "; FMA enabled" : "", Constants.HAS_FAST_VECTOR_FMA ? "; FMA enabled" : "",
PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
? "" ? ""
: "; floating-point vectors only")); : "; floating-point vectors only"));