mirror of https://github.com/apache/lucene.git

Speed up vectorutil float scalar methods, unroll properly, use fma where possible (#12737)
Co-authored-by: Uwe Schindler <uschindler@apache.org>
parent b8a9b0ae29
commit 40e55b0ce7
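The change below has two ingredients: a scalar fma() helper that calls Math.fma only when the platform is known to execute it as a fast CPU instruction (and otherwise falls back to plain multiply+add), and hand-unrolled loops that keep several independent accumulators so the multiply-adds are not serialized behind one loop-carried sum. A condensed, standalone sketch of that pattern is shown here; the class name and the hard-coded FAST_FMA flag are illustrative only, the real code gates on Constants.HAS_FAST_SCALAR_FMA detected at runtime:

// Condensed sketch (illustrative, not part of the commit): a guarded scalar FMA plus a
// 4-way unrolled main loop with independent accumulators and a scalar tail.
final class UnrolledDotProductSketch {
  // stands in for Constants.HAS_FAST_SCALAR_FMA, which the commit computes at runtime
  static final boolean FAST_FMA = false;

  private static float fma(float a, float b, float c) {
    return FAST_FMA ? Math.fma(a, b, c) : a * b + c;
  }

  static float dotProduct(float[] a, float[] b) {
    float res = 0f;
    int i = 0;
    if (a.length > 32) {
      float acc1 = 0, acc2 = 0, acc3 = 0, acc4 = 0;
      int upperBound = a.length & ~(4 - 1); // largest multiple of 4 <= a.length
      for (; i < upperBound; i += 4) {
        acc1 = fma(a[i], b[i], acc1);
        acc2 = fma(a[i + 1], b[i + 1], acc2);
        acc3 = fma(a[i + 2], b[i + 2], acc3);
        acc4 = fma(a[i + 3], b[i + 3], acc4);
      }
      res += acc1 + acc2 + acc3 + acc4;
    }
    for (; i < a.length; i++) {
      res = fma(a[i], b[i], res); // scalar tail for the leftover elements
    }
    return res;
  }
}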
@@ -17,72 +17,46 @@

 package org.apache.lucene.internal.vectorization;

+import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.SuppressForbidden;
+
 final class DefaultVectorUtilSupport implements VectorUtilSupport {

   DefaultVectorUtilSupport() {}

+  // the way FMA should work! if available use it, otherwise fall back to mul/add
+  @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
+  private static float fma(float a, float b, float c) {
+    if (Constants.HAS_FAST_SCALAR_FMA) {
+      return Math.fma(a, b, c);
+    } else {
+      return a * b + c;
+    }
+  }
+
   @Override
   public float dotProduct(float[] a, float[] b) {
     float res = 0f;
-    /*
-     * If length of vector is larger than 8, we use unrolled dot product to accelerate the
-     * calculation.
-     */
-    int i;
-    for (i = 0; i < a.length % 8; i++) {
-      res += b[i] * a[i];
+    int i = 0;
+
+    // if the array is big, unroll it
+    if (a.length > 32) {
+      float acc1 = 0;
+      float acc2 = 0;
+      float acc3 = 0;
+      float acc4 = 0;
+      int upperBound = a.length & ~(4 - 1);
+      for (; i < upperBound; i += 4) {
+        acc1 = fma(a[i], b[i], acc1);
+        acc2 = fma(a[i + 1], b[i + 1], acc2);
+        acc3 = fma(a[i + 2], b[i + 2], acc3);
+        acc4 = fma(a[i + 3], b[i + 3], acc4);
+      }
+      res += acc1 + acc2 + acc3 + acc4;
     }
-    if (a.length < 8) {
-      return res;
-    }
-    for (; i + 31 < a.length; i += 32) {
-      res +=
-          b[i + 0] * a[i + 0]
-              + b[i + 1] * a[i + 1]
-              + b[i + 2] * a[i + 2]
-              + b[i + 3] * a[i + 3]
-              + b[i + 4] * a[i + 4]
-              + b[i + 5] * a[i + 5]
-              + b[i + 6] * a[i + 6]
-              + b[i + 7] * a[i + 7];
-      res +=
-          b[i + 8] * a[i + 8]
-              + b[i + 9] * a[i + 9]
-              + b[i + 10] * a[i + 10]
-              + b[i + 11] * a[i + 11]
-              + b[i + 12] * a[i + 12]
-              + b[i + 13] * a[i + 13]
-              + b[i + 14] * a[i + 14]
-              + b[i + 15] * a[i + 15];
-      res +=
-          b[i + 16] * a[i + 16]
-              + b[i + 17] * a[i + 17]
-              + b[i + 18] * a[i + 18]
-              + b[i + 19] * a[i + 19]
-              + b[i + 20] * a[i + 20]
-              + b[i + 21] * a[i + 21]
-              + b[i + 22] * a[i + 22]
-              + b[i + 23] * a[i + 23];
-      res +=
-          b[i + 24] * a[i + 24]
-              + b[i + 25] * a[i + 25]
-              + b[i + 26] * a[i + 26]
-              + b[i + 27] * a[i + 27]
-              + b[i + 28] * a[i + 28]
-              + b[i + 29] * a[i + 29]
-              + b[i + 30] * a[i + 30]
-              + b[i + 31] * a[i + 31];
-    }
-    for (; i + 7 < a.length; i += 8) {
-      res +=
-          b[i + 0] * a[i + 0]
-              + b[i + 1] * a[i + 1]
-              + b[i + 2] * a[i + 2]
-              + b[i + 3] * a[i + 3]
-              + b[i + 4] * a[i + 4]
-              + b[i + 5] * a[i + 5]
-              + b[i + 6] * a[i + 6]
-              + b[i + 7] * a[i + 7];
+
+    for (; i < a.length; i++) {
+      res = fma(a[i], b[i], res);
     }
     return res;
   }
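Two details in the new loop are worth spelling out. The four accumulators acc1..acc4 break the loop-carried dependency, so several independent fused multiply-adds can be in flight at once instead of each waiting on the previous sum. The bound a.length & ~(4 - 1) clears the two lowest bits, i.e. rounds the length down to a multiple of 4, so the unrolled loop never reads past the end of the arrays; the leftovers go through the scalar tail. A tiny illustration (the length is made up):

// Worked example of the rounding mask used by the unrolled loops.
public class MaskExample {
  public static void main(String[] args) {
    int length = 103;                   // hypothetical vector length
    int upperBound = length & ~(4 - 1); // ~(4 - 1) == 0xFFFFFFFC, clears the low two bits
    System.out.println(upperBound);     // prints 100: indices 0..99 go through the unrolled loop,
                                        // indices 100..102 through the scalar tail
  }
}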
@@ -92,50 +66,80 @@ final class DefaultVectorUtilSupport implements VectorUtilSupport {
     float sum = 0.0f;
     float norm1 = 0.0f;
     float norm2 = 0.0f;
-    int dim = a.length;
+    int i = 0;

-    for (int i = 0; i < dim; i++) {
-      float elem1 = a[i];
-      float elem2 = b[i];
-      sum += elem1 * elem2;
-      norm1 += elem1 * elem1;
-      norm2 += elem2 * elem2;
+    // if the array is big, unroll it
+    if (a.length > 32) {
+      float sum1 = 0;
+      float sum2 = 0;
+      float norm1_1 = 0;
+      float norm1_2 = 0;
+      float norm2_1 = 0;
+      float norm2_2 = 0;
+
+      int upperBound = a.length & ~(2 - 1);
+      for (; i < upperBound; i += 2) {
+        // one
+        sum1 = fma(a[i], b[i], sum1);
+        norm1_1 = fma(a[i], a[i], norm1_1);
+        norm2_1 = fma(b[i], b[i], norm2_1);
+
+        // two
+        sum2 = fma(a[i + 1], b[i + 1], sum2);
+        norm1_2 = fma(a[i + 1], a[i + 1], norm1_2);
+        norm2_2 = fma(b[i + 1], b[i + 1], norm2_2);
+      }
+      sum += sum1 + sum2;
+      norm1 += norm1_1 + norm1_2;
+      norm2 += norm2_1 + norm2_2;
     }
+
+    for (; i < a.length; i++) {
+      sum = fma(a[i], b[i], sum);
+      norm1 = fma(a[i], a[i], norm1);
+      norm2 = fma(b[i], b[i], norm2);
+    }
     return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
   }

   @Override
   public float squareDistance(float[] a, float[] b) {
-    float squareSum = 0.0f;
-    int dim = a.length;
-    int i;
-    for (i = 0; i + 8 <= dim; i += 8) {
-      squareSum += squareDistanceUnrolled(a, b, i);
-    }
-    for (; i < dim; i++) {
-      float diff = a[i] - b[i];
-      squareSum += diff * diff;
-    }
-    return squareSum;
-  }
+    float res = 0;
+    int i = 0;

-  private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
-    float diff0 = v1[index + 0] - v2[index + 0];
-    float diff1 = v1[index + 1] - v2[index + 1];
-    float diff2 = v1[index + 2] - v2[index + 2];
-    float diff3 = v1[index + 3] - v2[index + 3];
-    float diff4 = v1[index + 4] - v2[index + 4];
-    float diff5 = v1[index + 5] - v2[index + 5];
-    float diff6 = v1[index + 6] - v2[index + 6];
-    float diff7 = v1[index + 7] - v2[index + 7];
-    return diff0 * diff0
-        + diff1 * diff1
-        + diff2 * diff2
-        + diff3 * diff3
-        + diff4 * diff4
-        + diff5 * diff5
-        + diff6 * diff6
-        + diff7 * diff7;
+    // if the array is big, unroll it
+    if (a.length > 32) {
+      float acc1 = 0;
+      float acc2 = 0;
+      float acc3 = 0;
+      float acc4 = 0;
+
+      int upperBound = a.length & ~(4 - 1);
+      for (; i < upperBound; i += 4) {
+        // one
+        float diff1 = a[i] - b[i];
+        acc1 = fma(diff1, diff1, acc1);
+
+        // two
+        float diff2 = a[i + 1] - b[i + 1];
+        acc2 = fma(diff2, diff2, acc2);
+
+        // three
+        float diff3 = a[i + 2] - b[i + 2];
+        acc3 = fma(diff3, diff3, acc3);
+
+        // four
+        float diff4 = a[i + 3] - b[i + 3];
+        acc4 = fma(diff4, diff4, acc4);
+      }
+      res += acc1 + acc2 + acc3 + acc4;
+    }
+
+    for (; i < a.length; i++) {
+      float diff = a[i] - b[i];
+      res = fma(diff, diff, res);
+    }
+    return res;
   }

   @Override
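A change like this is typically validated with a microbenchmark over the public entry points in org.apache.lucene.util.VectorUtil, which delegate to the scalar support class above when the Panama provider is not in use. The JMH harness below is only a sketch of how that could look; the benchmark class, the fixed seed and the 768-element size are assumptions, not part of the commit, and Lucene ships its own benchmarks elsewhere:

import java.util.Random;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;

// Hypothetical JMH microbenchmark for the scalar float dot product.
@State(Scope.Benchmark)
public class DotProductBench {
  float[] a;
  float[] b;

  @Setup
  public void setup() {
    Random r = new Random(42);        // fixed seed so runs are comparable
    a = new float[768];               // a typical embedding dimension; the size is an assumption
    b = new float[768];
    for (int i = 0; i < a.length; i++) {
      a[i] = r.nextFloat();
      b[i] = r.nextFloat();
    }
  }

  @Benchmark
  public float dot() {
    return org.apache.lucene.util.VectorUtil.dotProduct(a, b);
  }
}

Running the same benchmark once with -Dlucene.useScalarFMA=true and once with -Dlucene.useScalarFMA=false separates the gain from unrolling from the gain from the fused multiply-add itself.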
@@ -18,7 +18,6 @@ package org.apache.lucene.util;

 import java.security.AccessController;
 import java.security.PrivilegedAction;
-import java.util.Objects;
 import java.util.logging.Logger;

 /** Some useful constants. */
@@ -67,12 +66,6 @@ public final class Constants {
   /** True iff running on a 64bit JVM */
   public static final boolean JRE_IS_64BIT = is64Bit();

-  /** true iff we know fast FMA is supported, to deliver less error */
-  public static final boolean HAS_FAST_FMA =
-      (IS_CLIENT_VM == false)
-          && Objects.equals(OS_ARCH, "amd64")
-          && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
-
   private static boolean is64Bit() {
     final String datamodel = getSysProp("sun.arch.data.model");
     if (datamodel != null) {
@@ -82,6 +75,76 @@ public final class Constants {
     }
   }

+  /** true if FMA likely means a cpu instruction and not BigDecimal logic */
+  private static final boolean HAS_FMA =
+      (IS_CLIENT_VM == false) && HotspotVMOptions.get("UseFMA").map(Boolean::valueOf).orElse(false);
+
+  /** maximum supported vectorsize */
+  private static final int MAX_VECTOR_SIZE =
+      HotspotVMOptions.get("MaxVectorSize").map(Integer::valueOf).orElse(0);
+
+  /** true for an AMD cpu with SSE4a instructions */
+  private static final boolean HAS_SSE4A =
+      HotspotVMOptions.get("UseXmmI2F").map(Boolean::valueOf).orElse(false);
+
+  /** true iff we know VFMA has faster throughput than separate vmul/vadd */
+  public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();
+
+  /** true iff we know FMA has faster throughput than separate mul/add */
+  public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();
+
+  private static boolean hasFastVectorFMA() {
+    if (HAS_FMA) {
+      String value = getSysProp("lucene.useVectorFMA", "auto");
+      if ("auto".equals(value)) {
+        // newer Neoverse cores have their act together
+        // the problem is just apple silicon (this is a practical heuristic)
+        if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
+          return true;
+        }
+        // zen cores or newer, its a wash, turn it on as it doesn't hurt
+        // starts to yield gains for vectors only at zen4+
+        if (HAS_SSE4A && MAX_VECTOR_SIZE >= 32) {
+          return true;
+        }
+        // intel has their act together
+        if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
+          return true;
+        }
+      } else {
+        return Boolean.parseBoolean(value);
+      }
+    }
+    // everyone else is slow, until proven otherwise by benchmarks
+    return false;
+  }
+
+  private static boolean hasFastScalarFMA() {
+    if (HAS_FMA) {
+      String value = getSysProp("lucene.useScalarFMA", "auto");
+      if ("auto".equals(value)) {
+        // newer Neoverse cores have their act together
+        // the problem is just apple silicon (this is a practical heuristic)
+        if (OS_ARCH.equals("aarch64") && MAC_OS_X == false) {
+          return true;
+        }
+        // latency becomes 4 for the Zen3 (0x19h), but still a wash
+        // until the Zen4 anyway, and big drop on previous zens:
+        if (HAS_SSE4A && MAX_VECTOR_SIZE >= 64) {
+          return true;
+        }
+        // intel has their act together
+        if (OS_ARCH.equals("amd64") && HAS_SSE4A == false) {
+          return true;
+        }
+      } else {
+        return Boolean.parseBoolean(value);
+      }
+    }
+    // everyone else is slow, until proven otherwise by benchmarks
+    return false;
+  }
+
   private static String getSysProp(String property) {
     try {
       return doPrivileged(() -> System.getProperty(property));
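Because HAS_FAST_VECTOR_FMA and HAS_FAST_SCALAR_FMA are public, the outcome of these heuristics on a particular machine and JVM can be checked directly. A throwaway probe along these lines (not part of the commit; it assumes lucene-core is on the classpath):

import org.apache.lucene.util.Constants;

// Prints what the FMA heuristics resolved to on the current machine/JVM.
public class FmaProbe {
  public static void main(String[] args) {
    System.out.println("os.arch         = " + Constants.OS_ARCH);
    System.out.println("fast scalar FMA = " + Constants.HAS_FAST_SCALAR_FMA);
    System.out.println("fast vector FMA = " + Constants.HAS_FAST_VECTOR_FMA);
  }
}

On Apple silicon, or with -XX:-UseFMA, both flags resolve to false; elsewhere the result follows the CPU checks above.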
@@ -20,7 +20,31 @@ package org.apache.lucene.util;
 import org.apache.lucene.internal.vectorization.VectorUtilSupport;
 import org.apache.lucene.internal.vectorization.VectorizationProvider;

-/** Utilities for computations with numeric arrays */
+/**
+ * Utilities for computations with numeric arrays, especially algebraic operations like vector dot
+ * products. This class uses SIMD vectorization if the corresponding Java module is available and
+ * enabled. To enable vectorized code, pass {@code --add-modules jdk.incubator.vector} to Java's
+ * command line.
+ *
+ * <p>It will use CPU's <a href="https://en.wikipedia.org/wiki/Fused_multiply%E2%80%93add">FMA
+ * instructions</a> if it is known to perform faster than separate multiply+add. This requires at
+ * least Hotspot C2 enabled, which is the default for OpenJDK based JVMs.
+ *
+ * <p>To explicitly disable or enable FMA usage, pass the following system properties:
+ *
+ * <ul>
+ *   <li>{@code -Dlucene.useScalarFMA=(auto|true|false)} for scalar operations
+ *   <li>{@code -Dlucene.useVectorFMA=(auto|true|false)} for vectorized operations (with vector
+ *       incubator module)
+ * </ul>
+ *
+ * <p>The default is {@code auto}, which enables this for known CPU types and JVM settings. If
+ * Hotspot C2 is disabled, FMA and vectorization are <strong>not</strong> used.
+ *
+ * <p>Vectorization and FMA is only supported for Hotspot-based JVMs; it won't work on OpenJ9-based
+ * JVMs unless they provide {@link com.sun.management.HotSpotDiagnosticMXBean}. Please also make
+ * sure that you have the {@code jdk.management} module enabled in modularized applications.
+ */
 public final class VectorUtil {

   private static final VectorUtilSupport IMPL =
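Putting the knobs from the javadoc above together: enable the incubator module to get the SIMD path, and only pin the two FMA properties explicitly if the auto-detection misbehaves on a given machine. An example invocation (jar names and main class are placeholders):

java --add-modules jdk.incubator.vector \
     -Dlucene.useVectorFMA=auto \
     -Dlucene.useScalarFMA=false \
     -cp lucene-core.jar:app.jar com.example.SearchApp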
@@ -29,6 +29,7 @@ import jdk.incubator.vector.Vector;
 import jdk.incubator.vector.VectorShape;
 import jdk.incubator.vector.VectorSpecies;
 import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.SuppressForbidden;

 /**
  * VectorUtil methods implemented with Panama incubating vector API.
@@ -79,13 +80,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {

   // the way FMA should work! if available use it, otherwise fall back to mul/add
   private static FloatVector fma(FloatVector a, FloatVector b, FloatVector c) {
-    if (Constants.HAS_FAST_FMA) {
+    if (Constants.HAS_FAST_VECTOR_FMA) {
       return a.fma(b, c);
     } else {
       return a.mul(b).add(c);
     }
   }

+  @SuppressForbidden(reason = "Uses FMA only where fast and carefully contained")
+  private static float fma(float a, float b, float c) {
+    if (Constants.HAS_FAST_SCALAR_FMA) {
+      return Math.fma(a, b, c);
+    } else {
+      return a * b + c;
+    }
+  }
+
   @Override
   public float dotProduct(float[] a, float[] b) {
     int i = 0;
@@ -99,7 +109,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {

     // scalar tail
     for (; i < a.length; i++) {
-      res += b[i] * a[i];
+      res = fma(a[i], b[i], res);
     }
     return res;
   }
@@ -165,11 +175,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {

     // scalar tail
     for (; i < a.length; i++) {
-      float elem1 = a[i];
-      float elem2 = b[i];
-      sum += elem1 * elem2;
-      norm1 += elem1 * elem1;
-      norm2 += elem2 * elem2;
+      sum = fma(a[i], b[i], sum);
+      norm1 = fma(a[i], a[i], norm1);
+      norm2 = fma(b[i], b[i], norm2);
     }
     return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
   }
@@ -230,7 +238,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
     // scalar tail
     for (; i < a.length; i++) {
       float diff = a[i] - b[i];
-      res += diff * diff;
+      res = fma(diff, diff, res);
     }
     return res;
   }
@@ -63,7 +63,7 @@ final class PanamaVectorizationProvider extends VectorizationProvider {
             Locale.ENGLISH,
             "Java vector incubator API enabled; uses preferredBitSize=%d%s%s",
             PanamaVectorUtilSupport.VECTOR_BITSIZE,
-            Constants.HAS_FAST_FMA ? "; FMA enabled" : "",
+            Constants.HAS_FAST_VECTOR_FMA ? "; FMA enabled" : "",
             PanamaVectorUtilSupport.HAS_FAST_INTEGER_VECTORS
                 ? ""
                 : "; floating-point vectors only"));