stop using SPECIES_PREFERRED except at the top of this file

This commit is contained in:
Robert Muir 2023-10-14 18:08:33 -04:00
parent 3ec9c26d67
commit cce3cf106a
No known key found for this signature in database
GPG Key ID: 817AE1DD322D7ECA
1 changed file with 36 additions and 38 deletions

View File

@ -31,24 +31,22 @@ import jdk.incubator.vector.VectorSpecies;
final class PanamaVectorUtilSupport implements VectorUtilSupport {
private static final int INT_SPECIES_PREF_BIT_SIZE = IntVector.SPECIES_PREFERRED.vectorBitSize();
// we always use the platform's maximum floating point vector size
// we always use the platform's maximum floating point/int vector size
private static final VectorSpecies<Float> FLOAT_SPECIES = FloatVector.SPECIES_PREFERRED;
private static final VectorSpecies<Integer> INT_SPECIES = IntVector.SPECIES_PREFERRED;
// for integer methods, it is more complicated due to conversions
private static final int INT_SPECIES_VSIZE = INT_SPECIES.vectorBitSize();
private static final VectorSpecies<Byte> BYTE_SPECIES;
private static final VectorSpecies<Short> SHORT_SPECIES;
// compute BYTE/SHORT sizes relative to preferred integer vector size
static {
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
if (INT_SPECIES_VSIZE >= 256) {
BYTE_SPECIES =
ByteVector.SPECIES_MAX.withShape(
VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 2));
ByteVector.SPECIES_MAX.withShape(VectorShape.forBitSize(INT_SPECIES_VSIZE >> 2));
SHORT_SPECIES =
ShortVector.SPECIES_MAX.withShape(
VectorShape.forBitSize(IntVector.SPECIES_PREFERRED.vectorBitSize() >> 1));
ShortVector.SPECIES_MAX.withShape(VectorShape.forBitSize(INT_SPECIES_VSIZE >> 1));
} else {
BYTE_SPECIES = null;
SHORT_SPECIES = null;
@ -306,10 +304,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && useIntegerVectors) {
// compute vectorized dot product consistent with VPDPBUSD instruction
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
if (INT_SPECIES_VSIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.length);
res += dotProductBody512(a, b, i);
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
} else if (INT_SPECIES_VSIZE == 256) {
i += BYTE_SPECIES.loopBound(a.length);
res += dotProductBody256(a, b, i);
} else {
@ -328,7 +326,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/** vectorized dot product body (512 bit vectors) */
private int dotProductBody512(byte[] a, byte[] b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector acc = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
@ -339,7 +337,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
Vector<Short> prod16 = va16.mul(vb16);
// 32-bit add
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES, 0);
acc = acc.add(prod32);
}
// reduce
@ -348,14 +346,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/** vectorized dot product body (256 bit vectors) */
private int dotProductBody256(byte[] a, byte[] b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
IntVector acc = IntVector.zero(IntVector.SPECIES_256);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
// 32-bit multiply and add into accumulator
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_256, 0);
acc = acc.add(va32.mul(vb32));
}
// reduce
@ -394,10 +392,10 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && useIntegerVectors) {
final float[] ret;
if (INT_SPECIES_PREF_BIT_SIZE >= 512) {
if (INT_SPECIES_VSIZE >= 512) {
i += BYTE_SPECIES.loopBound(a.length);
ret = cosineBody512(a, b, i);
} else if (INT_SPECIES_PREF_BIT_SIZE == 256) {
} else if (INT_SPECIES_VSIZE == 256) {
i += BYTE_SPECIES.loopBound(a.length);
ret = cosineBody256(a, b, i);
} else {
@ -423,9 +421,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/** vectorized cosine body (512 bit vectors) */
private float[] cosineBody512(byte[] a, byte[] b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accSum = IntVector.zero(INT_SPECIES);
IntVector accNorm1 = IntVector.zero(INT_SPECIES);
IntVector accNorm2 = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
@ -438,9 +436,9 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
Vector<Short> prod16 = va16.mul(vb16);
// sum into accumulators: 32-bit add
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, INT_SPECIES, 0);
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, INT_SPECIES, 0);
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES, 0);
accNorm1 = accNorm1.add(norm1_32);
accNorm2 = accNorm2.add(norm2_32);
accSum = accSum.add(prod32);
@ -453,16 +451,16 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/** vectorized cosine body (256 bit vectors) */
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
IntVector accSum = IntVector.zero(IntVector.SPECIES_256);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_256);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_256);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length()) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
// 32-bit multiply, and add into accumulators
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_256, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_256, 0);
Vector<Integer> norm1_32 = va32.mul(va32);
Vector<Integer> norm2_32 = vb32.mul(vb32);
Vector<Integer> prod32 = va32.mul(vb32);
@ -511,7 +509,7 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && useIntegerVectors) {
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
if (INT_SPECIES_VSIZE >= 256) {
i += BYTE_SPECIES.loopBound(a.length);
res += squareDistanceBody256(a, b, i);
} else {
@ -530,15 +528,15 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
/** vectorized square distance body (256+ bit vectors) */
private int squareDistanceBody256(byte[] a, byte[] b, int limit) {
IntVector acc = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector acc = IntVector.zero(INT_SPECIES);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 32-bit sub, multiply, and add into accumulators
// TODO: uses AVX-512 heavy multiply on zmm, should we just use 256-bit vectors on AVX-512?
Vector<Integer> va32 = va8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES, 0);
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES, 0);
Vector<Integer> diff32 = va32.sub(vb32);
acc = acc.add(diff32.mul(diff32));
}