LUCENE-9636: Exact and operation to get a SIMD optimize (#2139)

Co-authored-by: 郭峰 <guofeng.my@bytedance.com>
This commit is contained in:
gf2121 2020-12-14 20:37:50 +08:00 committed by GitHub
parent bc854b2627
commit ecd47a8b7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 39 deletions

View File

@ -663,10 +663,11 @@ final class ForUtil {
private static void decode6(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 12);
shiftLongs(tmp, 12, longs, 0, 2, MASK8_6);
shiftLongs(tmp, 12, tmp, 0, 0, MASK8_2);
for (int iter = 0, tmpIdx = 0, longsIdx = 12; iter < 4; ++iter, tmpIdx += 3, longsIdx += 1) {
long l0 = (tmp[tmpIdx+0] & MASK8_2) << 4;
l0 |= (tmp[tmpIdx+1] & MASK8_2) << 2;
l0 |= (tmp[tmpIdx+2] & MASK8_2) << 0;
long l0 = tmp[tmpIdx+0] << 4;
l0 |= tmp[tmpIdx+1] << 2;
l0 |= tmp[tmpIdx+2] << 0;
longs[longsIdx+0] = l0;
}
}
@ -674,14 +675,15 @@ final class ForUtil {
private static void decode7(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 14);
shiftLongs(tmp, 14, longs, 0, 1, MASK8_7);
shiftLongs(tmp, 14, tmp, 0, 0, MASK8_1);
for (int iter = 0, tmpIdx = 0, longsIdx = 14; iter < 2; ++iter, tmpIdx += 7, longsIdx += 1) {
long l0 = (tmp[tmpIdx+0] & MASK8_1) << 6;
l0 |= (tmp[tmpIdx+1] & MASK8_1) << 5;
l0 |= (tmp[tmpIdx+2] & MASK8_1) << 4;
l0 |= (tmp[tmpIdx+3] & MASK8_1) << 3;
l0 |= (tmp[tmpIdx+4] & MASK8_1) << 2;
l0 |= (tmp[tmpIdx+5] & MASK8_1) << 1;
l0 |= (tmp[tmpIdx+6] & MASK8_1) << 0;
long l0 = tmp[tmpIdx+0] << 6;
l0 |= tmp[tmpIdx+1] << 5;
l0 |= tmp[tmpIdx+2] << 4;
l0 |= tmp[tmpIdx+3] << 3;
l0 |= tmp[tmpIdx+4] << 2;
l0 |= tmp[tmpIdx+5] << 1;
l0 |= tmp[tmpIdx+6] << 0;
longs[longsIdx+0] = l0;
}
}
@ -766,10 +768,11 @@ final class ForUtil {
private static void decode12(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 24);
shiftLongs(tmp, 24, longs, 0, 4, MASK16_12);
shiftLongs(tmp, 24, tmp, 0, 0, MASK16_4);
for (int iter = 0, tmpIdx = 0, longsIdx = 24; iter < 8; ++iter, tmpIdx += 3, longsIdx += 1) {
long l0 = (tmp[tmpIdx+0] & MASK16_4) << 8;
l0 |= (tmp[tmpIdx+1] & MASK16_4) << 4;
l0 |= (tmp[tmpIdx+2] & MASK16_4) << 0;
long l0 = tmp[tmpIdx+0] << 8;
l0 |= tmp[tmpIdx+1] << 4;
l0 |= tmp[tmpIdx+2] << 0;
longs[longsIdx+0] = l0;
}
}
@ -802,14 +805,15 @@ final class ForUtil {
private static void decode14(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 28);
shiftLongs(tmp, 28, longs, 0, 2, MASK16_14);
shiftLongs(tmp, 28, tmp, 0, 0, MASK16_2);
for (int iter = 0, tmpIdx = 0, longsIdx = 28; iter < 4; ++iter, tmpIdx += 7, longsIdx += 1) {
long l0 = (tmp[tmpIdx+0] & MASK16_2) << 12;
l0 |= (tmp[tmpIdx+1] & MASK16_2) << 10;
l0 |= (tmp[tmpIdx+2] & MASK16_2) << 8;
l0 |= (tmp[tmpIdx+3] & MASK16_2) << 6;
l0 |= (tmp[tmpIdx+4] & MASK16_2) << 4;
l0 |= (tmp[tmpIdx+5] & MASK16_2) << 2;
l0 |= (tmp[tmpIdx+6] & MASK16_2) << 0;
long l0 = tmp[tmpIdx+0] << 12;
l0 |= tmp[tmpIdx+1] << 10;
l0 |= tmp[tmpIdx+2] << 8;
l0 |= tmp[tmpIdx+3] << 6;
l0 |= tmp[tmpIdx+4] << 4;
l0 |= tmp[tmpIdx+5] << 2;
l0 |= tmp[tmpIdx+6] << 0;
longs[longsIdx+0] = l0;
}
}
@ -817,22 +821,23 @@ final class ForUtil {
private static void decode15(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 30);
shiftLongs(tmp, 30, longs, 0, 1, MASK16_15);
shiftLongs(tmp, 30, tmp, 0, 0, MASK16_1);
for (int iter = 0, tmpIdx = 0, longsIdx = 30; iter < 2; ++iter, tmpIdx += 15, longsIdx += 1) {
long l0 = (tmp[tmpIdx+0] & MASK16_1) << 14;
l0 |= (tmp[tmpIdx+1] & MASK16_1) << 13;
l0 |= (tmp[tmpIdx+2] & MASK16_1) << 12;
l0 |= (tmp[tmpIdx+3] & MASK16_1) << 11;
l0 |= (tmp[tmpIdx+4] & MASK16_1) << 10;
l0 |= (tmp[tmpIdx+5] & MASK16_1) << 9;
l0 |= (tmp[tmpIdx+6] & MASK16_1) << 8;
l0 |= (tmp[tmpIdx+7] & MASK16_1) << 7;
l0 |= (tmp[tmpIdx+8] & MASK16_1) << 6;
l0 |= (tmp[tmpIdx+9] & MASK16_1) << 5;
l0 |= (tmp[tmpIdx+10] & MASK16_1) << 4;
l0 |= (tmp[tmpIdx+11] & MASK16_1) << 3;
l0 |= (tmp[tmpIdx+12] & MASK16_1) << 2;
l0 |= (tmp[tmpIdx+13] & MASK16_1) << 1;
l0 |= (tmp[tmpIdx+14] & MASK16_1) << 0;
long l0 = tmp[tmpIdx+0] << 14;
l0 |= tmp[tmpIdx+1] << 13;
l0 |= tmp[tmpIdx+2] << 12;
l0 |= tmp[tmpIdx+3] << 11;
l0 |= tmp[tmpIdx+4] << 10;
l0 |= tmp[tmpIdx+5] << 9;
l0 |= tmp[tmpIdx+6] << 8;
l0 |= tmp[tmpIdx+7] << 7;
l0 |= tmp[tmpIdx+8] << 6;
l0 |= tmp[tmpIdx+9] << 5;
l0 |= tmp[tmpIdx+10] << 4;
l0 |= tmp[tmpIdx+11] << 3;
l0 |= tmp[tmpIdx+12] << 2;
l0 |= tmp[tmpIdx+13] << 1;
l0 |= tmp[tmpIdx+14] << 0;
longs[longsIdx+0] = l0;
}
}
@ -1117,10 +1122,11 @@ final class ForUtil {
private static void decode24(DataInput in, long[] tmp, long[] longs) throws IOException {
in.readLELongs(tmp, 0, 48);
shiftLongs(tmp, 48, longs, 0, 8, MASK32_24);
shiftLongs(tmp, 48, tmp, 0, 0, MASK32_8);
for (int iter = 0, tmpIdx = 0, longsIdx = 48; iter < 16; ++iter, tmpIdx += 3, longsIdx += 1) {
long l0 = (tmp[tmpIdx+0] & MASK32_8) << 16;
l0 |= (tmp[tmpIdx+1] & MASK32_8) << 8;
l0 |= (tmp[tmpIdx+2] & MASK32_8) << 0;
long l0 = tmp[tmpIdx+0] << 16;
l0 |= tmp[tmpIdx+1] << 8;
l0 |= tmp[tmpIdx+2] << 0;
longs[longsIdx+0] = l0;
}
}

View File

@ -366,6 +366,29 @@ final class ForUtil {
"""
def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long, o, num_values, f):
iteration = 1
num_longs = bpv * num_values / remaining_bits_per_long
while num_longs % 2 == 0 and num_values % 2 == 0:
num_longs /= 2
num_values /= 2
iteration *= 2
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
tmp_idx = 0
b = bpv
b -= remaining_bits_per_long
f.write(' long l0 = tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b))
tmp_idx += 1
while b >= remaining_bits_per_long:
b -= remaining_bits_per_long
f.write(' l0 |= tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b))
tmp_idx += 1
f.write(' longs[longsIdx+0] = l0;\n')
f.write(' }\n')
def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values, f):
iteration = 1
num_longs = bpv * num_values / remaining_bits_per_long
@ -417,7 +440,10 @@ def writeDecode(bpv, f):
o += bpv*2
shift -= bpv
if shift + bpv > 0:
writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
if bpv % (next_primitive % bpv) == 0:
writeRemainderWithSIMDOptimize(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
else:
writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
f.write(' }\n')
f.write('\n')