mirror of https://github.com/apache/lucene.git
stop using SPECIES_PREFERRED except at the top of this file
This commit is contained in:
parent
3ec9c26d67
commit
cce3cf106a
|
@ -31,24 +31,22 @@ import jdk.incubator.vector.VectorSpecies;
|
|||
|
||||
final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
||||
|
||||
private static final int INT_SPECIES_PREF_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize();
|
||||
|
||||
// we always use the platform's maximum floating point vector size
|
||||
// we always use the platform's maximum floating point/int vector size
|
||||
private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
|
||||
private static final VectorSpecies<Integer> INT_SPECIES = IntVector.SPECIES_PREFERRED;
|
||||
|
||||
// for integer methods, it is more complicated due to conversions
|
||||
private static final int INT_SPECIES_VSIZE = INT_SPECIES.vectorBitSize();
|
||||
private static final VectorSpecies<Byte> BYTE_SPECIES;
|
||||
private static final VectorSpecies<Short> SHORT_SPECIES;
|
||||
|
||||
// compute BYTE/SHORT sizes relative to preferred integer vector size
|
||||
static {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
if (INT_SPECIES_VSIZE >= 256) {
|
||||
BYTE_SPECIES =
|
||||
ByteVector.SPECIES_MAX.withShape(
|
||||
VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2));
|
||||
ByteVector.SPECIES_MAX.withShape(VectorShape.forBitSize(INT_SPECIES_VSIZE >> 2));
|
||||
SHORT_SPECIES =
|
||||
ShortVector.SPECIES_MAX.withShape(
|
||||
VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1));
|
||||
ShortVector.SPECIES_MAX.withShape(VectorShape.forBitSize(INT_SPECIES_VSIZE >> 1));
|
||||
} else {
|
||||
BYTE_SPECIES = null;
|
||||
SHORT_SPECIES = null;
|
||||
|
@ -306,10 +304,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && useIntegerVectors) {
|
||||
// compute vectorized dot product consistent with VPDPBUSD instruction
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
|
||||
if (INT_SPECIES_VSIZE >= 512) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
res += dotProductBody512(a, b, i);
|
||||
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
|
||||
} else if (INT_SPECIES_VSIZE == 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
res += dotProductBody256(a, b, i);
|
||||
} else {
|
||||
|
@ -328,7 +326,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
/** vectorized dot product body (512 bit vectors) */
|
||||
private int dotProductBody512(byte[] a, byte[] b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector acc = IntVector.zero(INT_SPECIES);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
@ -339,7 +337,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// 32-bit add
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES, 0);
|
||||
acc = acc.add(prod32);
|
||||
}
|
||||
// reduce
|
||||
|
@ -348,14 +346,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
/** vectorized dot product body (256 bit vectors) */
|
||||
private int dotProductBody256(byte[] a, byte[] b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_256);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
|
||||
// 32-bit multiply and add into accumulator
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_256, 0);
|
||||
acc = acc.add(va32.mul(vb32));
|
||||
}
|
||||
// reduce
|
||||
|
@ -394,10 +392,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && useIntegerVectors) {
|
||||
final float[] ret;
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
|
||||
if (INT_SPECIES_VSIZE >= 512) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
ret = cosineBody512(a, b, i);
|
||||
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
|
||||
} else if (INT_SPECIES_VSIZE == 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
ret = cosineBody256(a, b, i);
|
||||
} else {
|
||||
|
@ -423,9 +421,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
/** vectorized cosine body (512 bit vectors) */
|
||||
private float[] cosineBody512(byte[] a, byte[] b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accSum = IntVector.zero(INT_SPECIES);
|
||||
IntVector accNorm1 = IntVector.zero(INT_SPECIES);
|
||||
IntVector accNorm2 = IntVector.zero(INT_SPECIES);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
@ -438,9 +436,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// sum into accumulators: 32-bit add
|
||||
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, INT_SPECIES, 0);
|
||||
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, INT_SPECIES, 0);
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES, 0);
|
||||
accNorm1 = accNorm1.add(norm1_32);
|
||||
accNorm2 = accNorm2.add(norm2_32);
|
||||
accSum = accSum.add(prod32);
|
||||
|
@ -453,16 +451,16 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
/** vectorized cosine body (256 bit vectors) */
|
||||
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
|
||||
// 32-bit multiply, and add into accumulators
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_256, 0);
|
||||
Vector<Integer> norm1_32 = va32.mul(va32);
|
||||
Vector<Integer> norm2_32 = vb32.mul(vb32);
|
||||
Vector<Integer> prod32 = va32.mul(vb32);
|
||||
|
@ -511,7 +509,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && useIntegerVectors) {
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
if (INT_SPECIES_VSIZE >= 256) {
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
res += squareDistanceBody256(a, b, i);
|
||||
} else {
|
||||
|
@ -530,15 +528,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
|
||||
/** vectorized square distance body (256+ bit vectors) */
|
||||
private int squareDistanceBody256(byte[] a, byte[] b, int limit) {
|
||||
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector acc = IntVector.zero(INT_SPECIES);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 32-bit sub, multiply, and add into accumulators
|
||||
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES, 0);
|
||||
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES, 0);
|
||||
Vector<Integer> diff32 = va32.sub(vb32);
|
||||
acc = acc.add(diff32.mul(diff32));
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue