mirror of https://github.com/apache/lucene.git
cleanup cosine too, no perf impact
This commit is contained in:
parent
a72bf7ce68
commit
e93bd524cf
|
@ -127,71 +127,14 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
float sum = 0;
|
||||
float norm1 = 0;
|
||||
float norm2 = 0;
|
||||
|
||||
// if the array size is large (> 2x platform vector size), its worth the overhead to vectorize
|
||||
if (a.length > 2 * FLOAT_SPECIES.length()) {
|
||||
// vector loop is unrolled 4x (4 accumulators in parallel)
|
||||
// we don't know how many the cpu can do at once, some can do 2, some 4
|
||||
FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
int upperBound = FLOAT_SPECIES.loopBound(a.length - 3 * FLOAT_SPECIES.length());
|
||||
for (; i < upperBound; i += 4 * FLOAT_SPECIES.length()) {
|
||||
// one
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
|
||||
// two
|
||||
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
|
||||
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
|
||||
sum2 = sum2.add(vc.mul(vd));
|
||||
norm1_2 = norm1_2.add(vc.mul(vc));
|
||||
norm2_2 = norm2_2.add(vd.mul(vd));
|
||||
|
||||
// three
|
||||
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
|
||||
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
|
||||
sum3 = sum3.add(ve.mul(vf));
|
||||
norm1_3 = norm1_3.add(ve.mul(ve));
|
||||
norm2_3 = norm2_3.add(vf.mul(vf));
|
||||
|
||||
// four
|
||||
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
|
||||
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
|
||||
sum4 = sum4.add(vg.mul(vh));
|
||||
norm1_4 = norm1_4.add(vg.mul(vg));
|
||||
norm2_4 = norm2_4.add(vh.mul(vh));
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
upperBound = FLOAT_SPECIES.loopBound(a.length);
|
||||
for (; i < upperBound; i += FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
}
|
||||
// reduce
|
||||
FloatVector sumres1 = sum1.add(sum2);
|
||||
FloatVector sumres2 = sum3.add(sum4);
|
||||
FloatVector norm1res1 = norm1_1.add(norm1_2);
|
||||
FloatVector norm1res2 = norm1_3.add(norm1_4);
|
||||
FloatVector norm2res1 = norm2_1.add(norm2_2);
|
||||
FloatVector norm2res2 = norm2_3.add(norm2_4);
|
||||
sum += sumres1.add(sumres2).reduceLanes(ADD);
|
||||
norm1 += norm1res1.add(norm1res2).reduceLanes(ADD);
|
||||
norm2 += norm2res1.add(norm2res2).reduceLanes(ADD);
|
||||
i += FLOAT_SPECIES.loopBound(a.length);
|
||||
float[] ret = cosineBody(a, b, i);
|
||||
sum += ret[0];
|
||||
norm1 += ret[1];
|
||||
norm2 += ret[2];
|
||||
}
|
||||
|
||||
// scalar tail
|
||||
|
@ -205,6 +148,75 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
/** vectorized cosine body */
|
||||
private float[] cosineBody(float[] a, float[] b, int limit) {
|
||||
int i = 0;
|
||||
// vector loop is unrolled 4x (4 accumulators in parallel)
|
||||
// we don't know how many the cpu can do at once, some can do 2, some 4
|
||||
FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector sum4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm1_4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_1 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_2 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_3 = FloatVector.zero(FLOAT_SPECIES);
|
||||
FloatVector norm2_4 = FloatVector.zero(FLOAT_SPECIES);
|
||||
int unrolledLimit = limit - 3 * FLOAT_SPECIES.length();
|
||||
for (; i < unrolledLimit; i += 4 * FLOAT_SPECIES.length()) {
|
||||
// one
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
|
||||
// two
|
||||
FloatVector vc = FloatVector.fromArray(FLOAT_SPECIES, a, i + FLOAT_SPECIES.length());
|
||||
FloatVector vd = FloatVector.fromArray(FLOAT_SPECIES, b, i + FLOAT_SPECIES.length());
|
||||
sum2 = sum2.add(vc.mul(vd));
|
||||
norm1_2 = norm1_2.add(vc.mul(vc));
|
||||
norm2_2 = norm2_2.add(vd.mul(vd));
|
||||
|
||||
// three
|
||||
FloatVector ve = FloatVector.fromArray(FLOAT_SPECIES, a, i + 2 * FLOAT_SPECIES.length());
|
||||
FloatVector vf = FloatVector.fromArray(FLOAT_SPECIES, b, i + 2 * FLOAT_SPECIES.length());
|
||||
sum3 = sum3.add(ve.mul(vf));
|
||||
norm1_3 = norm1_3.add(ve.mul(ve));
|
||||
norm2_3 = norm2_3.add(vf.mul(vf));
|
||||
|
||||
// four
|
||||
FloatVector vg = FloatVector.fromArray(FLOAT_SPECIES, a, i + 3 * FLOAT_SPECIES.length());
|
||||
FloatVector vh = FloatVector.fromArray(FLOAT_SPECIES, b, i + 3 * FLOAT_SPECIES.length());
|
||||
sum4 = sum4.add(vg.mul(vh));
|
||||
norm1_4 = norm1_4.add(vg.mul(vg));
|
||||
norm2_4 = norm2_4.add(vh.mul(vh));
|
||||
}
|
||||
// vector tail: less scalar computations for unaligned sizes, esp with big vector sizes
|
||||
for (; i < limit; i += FLOAT_SPECIES.length()) {
|
||||
FloatVector va = FloatVector.fromArray(FLOAT_SPECIES, a, i);
|
||||
FloatVector vb = FloatVector.fromArray(FLOAT_SPECIES, b, i);
|
||||
sum1 = sum1.add(va.mul(vb));
|
||||
norm1_1 = norm1_1.add(va.mul(va));
|
||||
norm2_1 = norm2_1.add(vb.mul(vb));
|
||||
}
|
||||
// reduce
|
||||
FloatVector sumres1 = sum1.add(sum2);
|
||||
FloatVector sumres2 = sum3.add(sum4);
|
||||
FloatVector norm1res1 = norm1_1.add(norm1_2);
|
||||
FloatVector norm1res2 = norm1_3.add(norm1_4);
|
||||
FloatVector norm2res1 = norm2_1.add(norm2_2);
|
||||
FloatVector norm2res2 = norm2_3.add(norm2_4);
|
||||
return new float[] {
|
||||
sumres1.add(sumres2).reduceLanes(ADD),
|
||||
norm1res1.add(norm1res2).reduceLanes(ADD),
|
||||
norm2res1.add(norm2res2).reduceLanes(ADD)
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public float squareDistance(float[] a, float[] b) {
|
||||
int i = 0;
|
||||
|
@ -357,65 +369,22 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
int sum = 0;
|
||||
int norm1 = 0;
|
||||
int norm2 = 0;
|
||||
|
||||
// only vectorize if we'll at least enter the loop a single time, and we have at least 128-bit
|
||||
// vectors (256-bit on intel to dodge performance landmines)
|
||||
if (a.length >= 16 && useIntegerVectors) {
|
||||
final float[] ret;
|
||||
if (INT_SPECIES_PREF_BIT_SIZE >= 256) {
|
||||
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
|
||||
int upperBound = BYTE_SPECIES.loopBound(a.length);
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (; i < upperBound; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 16-bit multiply
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
|
||||
// sum into accumulators: 32-bit add
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
accSum = accSum.add(prod32);
|
||||
accNorm1 = accNorm1.add(norm1_32);
|
||||
accNorm2 = accNorm2.add(norm2_32);
|
||||
}
|
||||
// reduce
|
||||
sum += accSum.reduceLanes(ADD);
|
||||
norm1 += accNorm1.reduceLanes(ADD);
|
||||
norm2 += accNorm2.reduceLanes(ADD);
|
||||
i += BYTE_SPECIES.loopBound(a.length);
|
||||
ret = cosineBody256(a, b, i);
|
||||
} else {
|
||||
// 128-bit impl, which is tricky since we don't have SPECIES_32, it does "overlapping read"
|
||||
int upperBound = ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (; i < upperBound; i += ByteVector.SPECIES_64.length() >> 1) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
|
||||
// process first half only: 16-bit multiply
|
||||
Vector<Short> va16 = va8.convert(B2S, 0);
|
||||
Vector<Short> vb16 = vb8.convert(B2S, 0);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// sum into accumulators: 32-bit add
|
||||
accNorm1 = accNorm1.add(norm1_16.convertShape(S2I, IntVector.SPECIES_128, 0));
|
||||
accNorm2 = accNorm2.add(norm2_16.convertShape(S2I, IntVector.SPECIES_128, 0));
|
||||
accSum = accSum.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
|
||||
}
|
||||
// reduce
|
||||
sum += accSum.reduceLanes(ADD);
|
||||
norm1 += accNorm1.reduceLanes(ADD);
|
||||
norm2 += accNorm2.reduceLanes(ADD);
|
||||
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
|
||||
i += ByteVector.SPECIES_64.loopBound(a.length - ByteVector.SPECIES_64.length());
|
||||
ret = cosineBody128(a, b, i);
|
||||
}
|
||||
sum += ret[0];
|
||||
norm1 += ret[1];
|
||||
norm2 += ret[2];
|
||||
}
|
||||
|
||||
// scalar tail
|
||||
|
@ -429,6 +398,64 @@ final class PanamaVectorUtilSupport implements VectorUtilSupport {
|
|||
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||
}
|
||||
|
||||
/** vectorized cosine body (256+ bit vectors) */
|
||||
private float[] cosineBody256(byte[] a, byte[] b, int limit) {
|
||||
// optimized 256/512 bit implementation, processes 8/16 bytes at a time
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_PREFERRED);
|
||||
for (int i = 0; i < limit; i += BYTE_SPECIES.length()) {
|
||||
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(BYTE_SPECIES, b, i);
|
||||
|
||||
// 16-bit multiply
|
||||
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES, 0);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
|
||||
// sum into accumulators: 32-bit add
|
||||
Vector<Integer> prod32 = prod16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm1_32 = norm1_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
Vector<Integer> norm2_32 = norm2_16.convertShape(S2I, IntVector.SPECIES_PREFERRED, 0);
|
||||
accSum = accSum.add(prod32);
|
||||
accNorm1 = accNorm1.add(norm1_32);
|
||||
accNorm2 = accNorm2.add(norm2_32);
|
||||
}
|
||||
// reduce
|
||||
return new float[] {
|
||||
accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD)
|
||||
};
|
||||
}
|
||||
|
||||
/** vectorized cosine body (128 bit vectors) */
|
||||
private float[] cosineBody128(byte[] a, byte[] b, int limit) {
|
||||
IntVector accSum = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm1 = IntVector.zero(IntVector.SPECIES_128);
|
||||
IntVector accNorm2 = IntVector.zero(IntVector.SPECIES_128);
|
||||
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
|
||||
ByteVector va8 = ByteVector.fromArray(ByteVector.SPECIES_64, a, i);
|
||||
ByteVector vb8 = ByteVector.fromArray(ByteVector.SPECIES_64, b, i);
|
||||
|
||||
// process first half only: 16-bit multiply
|
||||
Vector<Short> va16 = va8.convert(B2S, 0);
|
||||
Vector<Short> vb16 = vb8.convert(B2S, 0);
|
||||
Vector<Short> norm1_16 = va16.mul(va16);
|
||||
Vector<Short> norm2_16 = vb16.mul(vb16);
|
||||
Vector<Short> prod16 = va16.mul(vb16);
|
||||
|
||||
// sum into accumulators: 32-bit add
|
||||
accNorm1 = accNorm1.add(norm1_16.convertShape(S2I, IntVector.SPECIES_128, 0));
|
||||
accNorm2 = accNorm2.add(norm2_16.convertShape(S2I, IntVector.SPECIES_128, 0));
|
||||
accSum = accSum.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
|
||||
}
|
||||
// reduce
|
||||
return new float[] {
|
||||
accSum.reduceLanes(ADD), accNorm1.reduceLanes(ADD), accNorm2.reduceLanes(ADD)
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int squareDistance(byte[] a, byte[] b) {
|
||||
int i = 0;
|
||||
|
|
Loading…
Reference in New Issue