mirror of https://github.com/apache/lucene.git
LUCENE-9636: Extract the AND operation to enable SIMD optimization (#2139)
Co-authored-by: 郭峰 <guofeng.my@bytedance.com>
This commit is contained in:
parent
bc854b2627
commit
ecd47a8b7b
|
@ -663,10 +663,11 @@ final class ForUtil {
|
|||
// decode6: reads 12 longs from `in` into tmp, runs one shiftLongs pass from tmp into longs (offset 0),
// then fills longs[12..15], each built by OR-ing three consecutive tmp entries at shifts 4, 2 and 0.
// NOTE(review): this is a rendered diff hunk (-663,10 +663,11) whose +/- markers were lost, so the
// pre-change and post-change loop bodies both appear below; a real file revision contains only one of them.
private static void decode6(DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||
in.readLELongs(tmp, 0, 12);
|
||||
shiftLongs(tmp, 12, longs, 0, 2, MASK8_6);
|
||||
// Added by this commit. Presumably applies MASK8_2 to the 12 tmp entries in place (shift 0) so the
// loop no longer needs a per-element AND; shiftLongs is defined outside this view — confirm.
shiftLongs(tmp, 12, tmp, 0, 0, MASK8_2);
|
||||
for (int iter = 0, tmpIdx = 0, longsIdx = 12; iter < 4; ++iter, tmpIdx += 3, longsIdx += 1) {
|
||||
// Pre-change loop body (removed by this commit): every tmp entry is masked with MASK8_2 before shifting.
long l0 = (tmp[tmpIdx+0] & MASK8_2) << 4;
|
||||
l0 |= (tmp[tmpIdx+1] & MASK8_2) << 2;
|
||||
l0 |= (tmp[tmpIdx+2] & MASK8_2) << 0;
|
||||
// Post-change loop body (added by this commit): same shifts, no per-element mask.
long l0 = tmp[tmpIdx+0] << 4;
|
||||
l0 |= tmp[tmpIdx+1] << 2;
|
||||
l0 |= tmp[tmpIdx+2] << 0;
|
||||
longs[longsIdx+0] = l0;
|
||||
}
|
||||
}
|
||||
|
@ -674,14 +675,15 @@ final class ForUtil {
|
|||
// decode7: reads 14 longs from `in` into tmp, runs one shiftLongs pass from tmp into longs (offset 0),
// then fills longs[14..15], each built by OR-ing seven consecutive tmp entries at shifts 6 down to 0.
// NOTE(review): this is a rendered diff hunk (-674,14 +675,15) whose +/- markers were lost, so the
// pre-change and post-change loop bodies both appear below; a real file revision contains only one of them.
private static void decode7(DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||
in.readLELongs(tmp, 0, 14);
|
||||
shiftLongs(tmp, 14, longs, 0, 1, MASK8_7);
|
||||
// Added by this commit. Presumably applies MASK8_1 to the 14 tmp entries in place (shift 0) so the
// loop no longer needs a per-element AND; shiftLongs is defined outside this view — confirm.
shiftLongs(tmp, 14, tmp, 0, 0, MASK8_1);
|
||||
for (int iter = 0, tmpIdx = 0, longsIdx = 14; iter < 2; ++iter, tmpIdx += 7, longsIdx += 1) {
|
||||
// Pre-change loop body (removed by this commit): every tmp entry is masked with MASK8_1 before shifting.
long l0 = (tmp[tmpIdx+0] & MASK8_1) << 6;
|
||||
l0 |= (tmp[tmpIdx+1] & MASK8_1) << 5;
|
||||
l0 |= (tmp[tmpIdx+2] & MASK8_1) << 4;
|
||||
l0 |= (tmp[tmpIdx+3] & MASK8_1) << 3;
|
||||
l0 |= (tmp[tmpIdx+4] & MASK8_1) << 2;
|
||||
l0 |= (tmp[tmpIdx+5] & MASK8_1) << 1;
|
||||
l0 |= (tmp[tmpIdx+6] & MASK8_1) << 0;
|
||||
// Post-change loop body (added by this commit): same shifts, no per-element mask.
long l0 = tmp[tmpIdx+0] << 6;
|
||||
l0 |= tmp[tmpIdx+1] << 5;
|
||||
l0 |= tmp[tmpIdx+2] << 4;
|
||||
l0 |= tmp[tmpIdx+3] << 3;
|
||||
l0 |= tmp[tmpIdx+4] << 2;
|
||||
l0 |= tmp[tmpIdx+5] << 1;
|
||||
l0 |= tmp[tmpIdx+6] << 0;
|
||||
longs[longsIdx+0] = l0;
|
||||
}
|
||||
}
|
||||
|
@ -766,10 +768,11 @@ final class ForUtil {
|
|||
// decode12: reads 24 longs from `in` into tmp, runs one shiftLongs pass from tmp into longs (offset 0),
// then fills longs[24..31], each built by OR-ing three consecutive tmp entries at shifts 8, 4 and 0.
// NOTE(review): this is a rendered diff hunk (-766,10 +768,11) whose +/- markers were lost, so the
// pre-change and post-change loop bodies both appear below; a real file revision contains only one of them.
private static void decode12(DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||
in.readLELongs(tmp, 0, 24);
|
||||
shiftLongs(tmp, 24, longs, 0, 4, MASK16_12);
|
||||
// Added by this commit. Presumably applies MASK16_4 to the 24 tmp entries in place (shift 0) so the
// loop no longer needs a per-element AND; shiftLongs is defined outside this view — confirm.
shiftLongs(tmp, 24, tmp, 0, 0, MASK16_4);
|
||||
for (int iter = 0, tmpIdx = 0, longsIdx = 24; iter < 8; ++iter, tmpIdx += 3, longsIdx += 1) {
|
||||
// Pre-change loop body (removed by this commit): every tmp entry is masked with MASK16_4 before shifting.
long l0 = (tmp[tmpIdx+0] & MASK16_4) << 8;
|
||||
l0 |= (tmp[tmpIdx+1] & MASK16_4) << 4;
|
||||
l0 |= (tmp[tmpIdx+2] & MASK16_4) << 0;
|
||||
// Post-change loop body (added by this commit): same shifts, no per-element mask.
long l0 = tmp[tmpIdx+0] << 8;
|
||||
l0 |= tmp[tmpIdx+1] << 4;
|
||||
l0 |= tmp[tmpIdx+2] << 0;
|
||||
longs[longsIdx+0] = l0;
|
||||
}
|
||||
}
|
||||
|
@ -802,14 +805,15 @@ final class ForUtil {
|
|||
// decode14: reads 28 longs from `in` into tmp, runs one shiftLongs pass from tmp into longs (offset 0),
// then fills longs[28..31], each built by OR-ing seven consecutive tmp entries at shifts 12, 10, ..., 0.
// NOTE(review): this is a rendered diff hunk (-802,14 +805,15) whose +/- markers were lost, so the
// pre-change and post-change loop bodies both appear below; a real file revision contains only one of them.
private static void decode14(DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||
in.readLELongs(tmp, 0, 28);
|
||||
shiftLongs(tmp, 28, longs, 0, 2, MASK16_14);
|
||||
// Added by this commit. Presumably applies MASK16_2 to the 28 tmp entries in place (shift 0) so the
// loop no longer needs a per-element AND; shiftLongs is defined outside this view — confirm.
shiftLongs(tmp, 28, tmp, 0, 0, MASK16_2);
|
||||
for (int iter = 0, tmpIdx = 0, longsIdx = 28; iter < 4; ++iter, tmpIdx += 7, longsIdx += 1) {
|
||||
// Pre-change loop body (removed by this commit): every tmp entry is masked with MASK16_2 before shifting.
long l0 = (tmp[tmpIdx+0] & MASK16_2) << 12;
|
||||
l0 |= (tmp[tmpIdx+1] & MASK16_2) << 10;
|
||||
l0 |= (tmp[tmpIdx+2] & MASK16_2) << 8;
|
||||
l0 |= (tmp[tmpIdx+3] & MASK16_2) << 6;
|
||||
l0 |= (tmp[tmpIdx+4] & MASK16_2) << 4;
|
||||
l0 |= (tmp[tmpIdx+5] & MASK16_2) << 2;
|
||||
l0 |= (tmp[tmpIdx+6] & MASK16_2) << 0;
|
||||
// Post-change loop body (added by this commit): same shifts, no per-element mask.
long l0 = tmp[tmpIdx+0] << 12;
|
||||
l0 |= tmp[tmpIdx+1] << 10;
|
||||
l0 |= tmp[tmpIdx+2] << 8;
|
||||
l0 |= tmp[tmpIdx+3] << 6;
|
||||
l0 |= tmp[tmpIdx+4] << 4;
|
||||
l0 |= tmp[tmpIdx+5] << 2;
|
||||
l0 |= tmp[tmpIdx+6] << 0;
|
||||
longs[longsIdx+0] = l0;
|
||||
}
|
||||
}
|
||||
|
@ -817,22 +821,23 @@ final class ForUtil {
|
|||
// decode15: reads 30 longs from `in` into tmp, runs one shiftLongs pass from tmp into longs (offset 0),
// then fills longs[30..31], each built by OR-ing fifteen consecutive tmp entries at shifts 14 down to 0.
// NOTE(review): this is a rendered diff hunk (-817,22 +821,23) whose +/- markers were lost, so the
// pre-change and post-change loop bodies both appear below; a real file revision contains only one of them.
private static void decode15(DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||
in.readLELongs(tmp, 0, 30);
|
||||
shiftLongs(tmp, 30, longs, 0, 1, MASK16_15);
|
||||
// Added by this commit. Presumably applies MASK16_1 to the 30 tmp entries in place (shift 0) so the
// loop no longer needs a per-element AND; shiftLongs is defined outside this view — confirm.
shiftLongs(tmp, 30, tmp, 0, 0, MASK16_1);
|
||||
for (int iter = 0, tmpIdx = 0, longsIdx = 30; iter < 2; ++iter, tmpIdx += 15, longsIdx += 1) {
|
||||
// Pre-change loop body (removed by this commit): every tmp entry is masked with MASK16_1 before shifting.
long l0 = (tmp[tmpIdx+0] & MASK16_1) << 14;
|
||||
l0 |= (tmp[tmpIdx+1] & MASK16_1) << 13;
|
||||
l0 |= (tmp[tmpIdx+2] & MASK16_1) << 12;
|
||||
l0 |= (tmp[tmpIdx+3] & MASK16_1) << 11;
|
||||
l0 |= (tmp[tmpIdx+4] & MASK16_1) << 10;
|
||||
l0 |= (tmp[tmpIdx+5] & MASK16_1) << 9;
|
||||
l0 |= (tmp[tmpIdx+6] & MASK16_1) << 8;
|
||||
l0 |= (tmp[tmpIdx+7] & MASK16_1) << 7;
|
||||
l0 |= (tmp[tmpIdx+8] & MASK16_1) << 6;
|
||||
l0 |= (tmp[tmpIdx+9] & MASK16_1) << 5;
|
||||
l0 |= (tmp[tmpIdx+10] & MASK16_1) << 4;
|
||||
l0 |= (tmp[tmpIdx+11] & MASK16_1) << 3;
|
||||
l0 |= (tmp[tmpIdx+12] & MASK16_1) << 2;
|
||||
l0 |= (tmp[tmpIdx+13] & MASK16_1) << 1;
|
||||
l0 |= (tmp[tmpIdx+14] & MASK16_1) << 0;
|
||||
// Post-change loop body (added by this commit): same shifts, no per-element mask.
long l0 = tmp[tmpIdx+0] << 14;
|
||||
l0 |= tmp[tmpIdx+1] << 13;
|
||||
l0 |= tmp[tmpIdx+2] << 12;
|
||||
l0 |= tmp[tmpIdx+3] << 11;
|
||||
l0 |= tmp[tmpIdx+4] << 10;
|
||||
l0 |= tmp[tmpIdx+5] << 9;
|
||||
l0 |= tmp[tmpIdx+6] << 8;
|
||||
l0 |= tmp[tmpIdx+7] << 7;
|
||||
l0 |= tmp[tmpIdx+8] << 6;
|
||||
l0 |= tmp[tmpIdx+9] << 5;
|
||||
l0 |= tmp[tmpIdx+10] << 4;
|
||||
l0 |= tmp[tmpIdx+11] << 3;
|
||||
l0 |= tmp[tmpIdx+12] << 2;
|
||||
l0 |= tmp[tmpIdx+13] << 1;
|
||||
l0 |= tmp[tmpIdx+14] << 0;
|
||||
longs[longsIdx+0] = l0;
|
||||
}
|
||||
}
|
||||
|
@ -1117,10 +1122,11 @@ final class ForUtil {
|
|||
// decode24: reads 48 longs from `in` into tmp, runs one shiftLongs pass from tmp into longs (offset 0),
// then fills longs[48..63], each built by OR-ing three consecutive tmp entries at shifts 16, 8 and 0.
// NOTE(review): this is a rendered diff hunk (-1117,10 +1122,11) whose +/- markers were lost, so the
// pre-change and post-change loop bodies both appear below; a real file revision contains only one of them.
private static void decode24(DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||
in.readLELongs(tmp, 0, 48);
|
||||
shiftLongs(tmp, 48, longs, 0, 8, MASK32_24);
|
||||
// Added by this commit. Presumably applies MASK32_8 to the 48 tmp entries in place (shift 0) so the
// loop no longer needs a per-element AND; shiftLongs is defined outside this view — confirm.
shiftLongs(tmp, 48, tmp, 0, 0, MASK32_8);
|
||||
for (int iter = 0, tmpIdx = 0, longsIdx = 48; iter < 16; ++iter, tmpIdx += 3, longsIdx += 1) {
|
||||
// Pre-change loop body (removed by this commit): every tmp entry is masked with MASK32_8 before shifting.
long l0 = (tmp[tmpIdx+0] & MASK32_8) << 16;
|
||||
l0 |= (tmp[tmpIdx+1] & MASK32_8) << 8;
|
||||
l0 |= (tmp[tmpIdx+2] & MASK32_8) << 0;
|
||||
// Post-change loop body (added by this commit): same shifts, no per-element mask.
long l0 = tmp[tmpIdx+0] << 16;
|
||||
l0 |= tmp[tmpIdx+1] << 8;
|
||||
l0 |= tmp[tmpIdx+2] << 0;
|
||||
longs[longsIdx+0] = l0;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -366,6 +366,29 @@ final class ForUtil {
|
|||
|
||||
"""
|
||||
|
||||
# Writes to `f` the Java source for the "remainder" part of a decode method: one
# shiftLongs(tmp, N, tmp, 0, 0, MASK<next_primitive>_<remaining_bits_per_long>) call followed by a
# for loop that ORs shifted tmp entries into a single long and stores it at longs[longsIdx+0].
# Unlike the masked form, the emitted loop lines contain no `& MASK` on each tmp entry.
#   bpv                     -- starting shift budget; the emitted shifts count down from it
#   next_primitive          -- first number in the emitted MASK<a>_<b> constant name
#   remaining_bits_per_long -- second number in the mask name, and the step between emitted shifts
#   o                       -- initial value of longsIdx in the emitted loop
#   num_values              -- reduced below, then emitted as the longsIdx increment
#   f                       -- file-like object receiving the generated text
# NOTE(review): this rendering lost the original indentation, so block nesting is not visible here.
def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long, o, num_values, f):
|
||||
iteration = 1
|
||||
# NOTE(review): `/` here and `/=` below are true division on Python 3, so these values may be
# floats that the later %d formatting truncates — confirm which interpreter the generator targets.
num_longs = bpv * num_values / remaining_bits_per_long
|
||||
# Halve both counts while both stay even; `iteration` ends up as the emitted loop's trip count.
while num_longs % 2 == 0 and num_values % 2 == 0:
|
||||
num_longs /= 2
|
||||
num_values /= 2
|
||||
iteration *= 2
|
||||
|
||||
# iteration * num_longs is the number of tmp entries passed to the emitted shiftLongs call.
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
|
||||
# Loop header: tmpIdx advances by num_longs and longsIdx by num_values on each emitted iteration.
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
|
||||
tmp_idx = 0
|
||||
b = bpv
|
||||
b -= remaining_bits_per_long
|
||||
# The first tmp entry gets the largest shift and declares l0.
f.write(' long l0 = tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b))
|
||||
tmp_idx += 1
|
||||
# Each further tmp entry is OR-ed in at a shift remaining_bits_per_long smaller than the previous one.
while b >= remaining_bits_per_long:
|
||||
b -= remaining_bits_per_long
|
||||
f.write(' l0 |= tmp[tmpIdx+%d] << %d;\n' %(tmp_idx, b))
|
||||
tmp_idx += 1
|
||||
f.write(' longs[longsIdx+0] = l0;\n')
|
||||
# Closes the emitted for loop.
f.write(' }\n')
|
||||
|
||||
|
||||
def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values, f):
|
||||
iteration = 1
|
||||
num_longs = bpv * num_values / remaining_bits_per_long
|
||||
|
@ -417,7 +440,10 @@ def writeDecode(bpv, f):
|
|||
o += bpv*2
|
||||
shift -= bpv
|
||||
if shift + bpv > 0:
|
||||
writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
|
||||
if bpv % (next_primitive % bpv) == 0:
|
||||
writeRemainderWithSIMDOptimize(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
|
||||
else:
|
||||
writeRemainder(bpv, next_primitive, shift + bpv, o, 128/num_values_per_long - o, f)
|
||||
f.write(' }\n')
|
||||
f.write('\n')
|
||||
|
||||
|
|
Loading…
Reference in New Issue