cleanup cosine too, no perf impact

This commit is contained in:
Robert Muir 2023-10-14 14:10:32 -04:00
parent a72bf7ce68
commit e93bd524cf
No known key found for this signature in database
GPG Key ID: 817AE1DD322D7ECA
1 changed files with 143 additions and 116 deletions

View File

@ -127,71 +127,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
float sum = 0;
float norm1 = 0;
float norm2 = 0;
// if the array size is large (> 2x platform vector size), its worth the overhead to vectorize
if (a.length > 2 * FLOAT_SPECIES.length()) {
// vector loop is unrolled 4x (4 accumulators in parallel)
// we don't know how many the cpu can do at once, some can do 2, some 4
FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES);
int upperBound = FLOAT_SPECIES.loopBound(a.length - 3 * FLOAT_SPECIES.length());
for (; i < upperBound; i += 4 * FLOAT_SPECIES.length()) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = sum1.add(va.mul(vb));
norm1_1 = norm1_1.add(va.mul(va));
norm2_1 = norm2_1.add(vb.mul(vb));
// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
sum2 = sum2.add(vc.mul(vd));
norm1_2 = norm1_2.add(vc.mul(vc));
norm2_2 = norm2_2.add(vd.mul(vd));
// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
sum3 = sum3.add(ve.mul(vf));
norm1_3 = norm1_3.add(ve.mul(ve));
norm2_3 = norm2_3.add(vf.mul(vf));
// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
sum4 = sum4.add(vg.mul(vh));
norm1_4 = norm1_4.add(vg.mul(vg));
norm2_4 = norm2_4.add(vh.mul(vh));
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
upperBound = FLOAT_SPECIES.loopBound(a.length);
for (; i < upperBound; i += FLOAT_SPECIES.length()) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = sum1.add(va.mul(vb));
norm1_1 = norm1_1.add(va.mul(va));
norm2_1 = norm2_1.add(vb.mul(vb));
}
// reduce
FloatVector sumres1 = sum1.add(sum2);
FloatVector sumres2 = sum3.add(sum4);
FloatVector norm1res1 = norm1_1.add(norm1_2);
FloatVector norm1res2 = norm1_3.add(norm1_4);
FloatVector norm2res1 = norm2_1.add(norm2_2);
FloatVector norm2res2 = norm2_3.add(norm2_4);
sum += sumres1.add(sumres2).reduceLanes(ADD);
norm1 += norm1res1.add(norm1res2).reduceLanes(ADD);
norm2 += norm2res1.add(norm2res2).reduceLanes(ADD);
i += FLOAT_SPECIES.loopBound(a.length);
float[] ret = cosineBody(a, b, i);
sum += ret[0];
norm1 += ret[1];
norm2 += ret[2];
}
// scalar tail
@ -205,6 +148,75 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
/** vectorized cosine body */
private float[] cosineBody(float[] a, float[] b, int limit) {
int i = 0;
// vector loop is unrolled 4x (4 accumulators in parallel)
// we don't know how many the cpu can do at once, some can do 2, some 4
FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES);
FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
// one
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = sum1.add(va.mul(vb));
norm1_1 = norm1_1.add(va.mul(va));
norm2_1 = norm2_1.add(vb.mul(vb));
// two
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
sum2 = sum2.add(vc.mul(vd));
norm1_2 = norm1_2.add(vc.mul(vc));
norm2_2 = norm2_2.add(vd.mul(vd));
// three
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
sum3 = sum3.add(ve.mul(vf));
norm1_3 = norm1_3.add(ve.mul(ve));
norm2_3 = norm2_3.add(vf.mul(vf));
// four
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
sum4 = sum4.add(vg.mul(vh));
norm1_4 = norm1_4.add(vg.mul(vg));
norm2_4 = norm2_4.add(vh.mul(vh));
}
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
sum1 = sum1.add(va.mul(vb));
norm1_1 = norm1_1.add(va.mul(va));
norm2_1 = norm2_1.add(vb.mul(vb));
}
// reduce
FloatVector sumres1 = sum1.add(sum2);
FloatVector sumres2 = sum3.add(sum4);
FloatVector norm1res1 = norm1_1.add(norm1_2);
FloatVector norm1res2 = norm1_3.add(norm1_4);
FloatVector norm2res1 = norm2_1.add(norm2_2);
FloatVector norm2res2 = norm2_3.add(norm2_4);
return new float[] {
sumres1.add(sumres2).reduceLanes(ADD),
norm1res1.add(norm1res2).reduceLanes(ADD),
norm2res1.add(norm2res2).reduceLanes(ADD)
};
}
@Override
public float squareDistance(float[] a, float[] b) {
int i = 0;
@ -357,65 +369,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
int sum = 0;
int norm1 = 0;
int norm2 = 0;
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
// vectors (256-bit on intel to dodge performance landmines)
if (a.length >= 16 && useIntegerVectors) {
final float[] ret;
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
int upperBound = BYTE_SPECIES.loopBound(a.length);
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
for (; i < upperBound; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 16-bit multiply
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> prod16 = va16.mul(vb16);
Vector<Short> norm1_16 = va16.mul(va16);
Vector<Short> norm2_16 = vb16.mul(vb16);
// sum into accumulators: 32-bit add
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
accSum = accSum.add(prod32);
accNorm1 = accNorm1.add(norm1_32);
accNorm2 = accNorm2.add(norm2_32);
}
// reduce
sum += accSum.reduceLanes(ADD);
norm1 += accNorm1.reduceLanes(ADD);
norm2 += accNorm2.reduceLanes(ADD);
i += BYTE_SPECIES.loopBound(a.length);
ret = cosineBody256(a, b, i);
} else {
// 128-bit impl, which is tricky since we don't have SPECIES_32, it does "overlapping read"
int upperBound = ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
for (; i < upperBound; i += ByteVector.SPECIES_64.length() >> 1) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
// process first half only: 16-bit multiply
Vector<Short> va16 = va8.convert(B2S, 0);
Vector<Short> vb16 = vb8.convert(B2S, 0);
Vector<Short> norm1_16 = va16.mul(va16);
Vector<Short> norm2_16 = vb16.mul(vb16);
Vector<Short> prod16 = va16.mul(vb16);
// sum into accumulators: 32-bit add
accNorm1 = accNorm1.add(norm1_16.convertShape(S2I, IntVector.SPECIES_128, 0));
accNorm2 = accNorm2.add(norm2_16.convertShape(S2I, IntVector.SPECIES_128, 0));
accSum = accSum.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
}
// reduce
sum += accSum.reduceLanes(ADD);
norm1 += accNorm1.reduceLanes(ADD);
norm2 += accNorm2.reduceLanes(ADD);
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
ret = cosineBody128(a, b, i);
}
sum += ret[0];
norm1 += ret[1];
norm2 += ret[2];
}
// scalar tail
@ -429,6 +398,64 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
}
/** vectorized cosine body (256+ bit vectors) */
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
// 16-bit multiply
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
Vector<Short> prod16 = va16.mul(vb16);
Vector<Short> norm1_16 = va16.mul(va16);
Vector<Short> norm2_16 = vb16.mul(vb16);
// sum into accumulators: 32-bit add
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
accSum = accSum.add(prod32);
accNorm1 = accNorm1.add(norm1_32);
accNorm2 = accNorm2.add(norm2_32);
}
// reduce
return new float[] {
accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD)
};
}
/** vectorized cosine body (128 bit vectors) */
private float[] cosineBody128(byte[] a, byte[] b, int limit) {
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
// process first half only: 16-bit multiply
Vector<Short> va16 = va8.convert(B2S, 0);
Vector<Short> vb16 = vb8.convert(B2S, 0);
Vector<Short> norm1_16 = va16.mul(va16);
Vector<Short> norm2_16 = vb16.mul(vb16);
Vector<Short> prod16 = va16.mul(vb16);
// sum into accumulators: 32-bit add
accNorm1 = accNorm1.add(norm1_16.convertShape(S2I, IntVector.SPECIES_128, 0));
accNorm2 = accNorm2.add(norm2_16.convertShape(S2I, IntVector.SPECIES_128, 0));
accSum = accSum.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
}
// reduce
return new float[] {
accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD)
};
}
@Override
public int squareDistance(byte[] a, byte[] b) {
int i = 0;