LUCENE-9629: Use computed masks (#2113)

Co-authored-by: 郭峰 <guofeng.my@bytedance.com>
This commit is contained in:
gf2121 2020-12-18 08:59:40 -06:00 committed by GitHub
parent 894b6b5c88
commit 00a5637457
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 101 additions and 74 deletions

View File

@ -250,17 +250,17 @@ final class ForUtil {
final int remainingBitsPerLong = shift + bitsPerValue;
final long maskRemainingBitsPerLong;
if (nextPrimitive == 8) {
maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
} else if (nextPrimitive == 16) {
maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
} else {
maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
}
int tmpIdx = 0;
int remainingBitsPerValue = bitsPerValue;
while (idx < numLongs) {
if (remainingBitsPerValue > remainingBitsPerLong) {
if (remainingBitsPerValue >= remainingBitsPerLong) {
remainingBitsPerValue -= remainingBitsPerLong;
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
if (remainingBitsPerValue == 0) {
@ -270,14 +270,14 @@ final class ForUtil {
} else {
final long mask1, mask2;
if (nextPrimitive == 8) {
mask1 = mask8(remainingBitsPerValue);
mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
mask1 = MASKS8[remainingBitsPerValue];
mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
} else if (nextPrimitive == 16) {
mask1 = mask16(remainingBitsPerValue);
mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
mask1 = MASKS16[remainingBitsPerValue];
mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
} else {
mask1 = mask32(remainingBitsPerValue);
mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
mask1 = MASKS32[remainingBitsPerValue];
mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
}
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@ -302,7 +302,7 @@ final class ForUtil {
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
final int numLongs = bitsPerValue << 1;
in.readLELongs(tmp, 0, numLongs);
final long mask = mask32(bitsPerValue);
final long mask = MASKS32[bitsPerValue];
int longsIdx = 0;
int shift = 32 - bitsPerValue;
for (; shift >= 0; shift -= bitsPerValue) {
@ -310,18 +310,18 @@ final class ForUtil {
longsIdx += numLongs;
}
final int remainingBitsPerLong = shift + bitsPerValue;
final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
int tmpIdx = 0;
int remainingBits = remainingBitsPerLong;
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
int b = bitsPerValue - remainingBits;
long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
while (b >= remainingBitsPerLong) {
b -= remainingBitsPerLong;
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
}
if (b > 0) {
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
remainingBits = remainingBitsPerLong - b;
} else {
remainingBits = remainingBitsPerLong;
@ -341,50 +341,65 @@ final class ForUtil {
}
}
private static final long MASK8_1 = mask8(1);
private static final long MASK8_2 = mask8(2);
private static final long MASK8_3 = mask8(3);
private static final long MASK8_4 = mask8(4);
private static final long MASK8_5 = mask8(5);
private static final long MASK8_6 = mask8(6);
private static final long MASK8_7 = mask8(7);
private static final long MASK16_1 = mask16(1);
private static final long MASK16_2 = mask16(2);
private static final long MASK16_3 = mask16(3);
private static final long MASK16_4 = mask16(4);
private static final long MASK16_5 = mask16(5);
private static final long MASK16_6 = mask16(6);
private static final long MASK16_7 = mask16(7);
private static final long MASK16_9 = mask16(9);
private static final long MASK16_10 = mask16(10);
private static final long MASK16_11 = mask16(11);
private static final long MASK16_12 = mask16(12);
private static final long MASK16_13 = mask16(13);
private static final long MASK16_14 = mask16(14);
private static final long MASK16_15 = mask16(15);
private static final long MASK32_1 = mask32(1);
private static final long MASK32_2 = mask32(2);
private static final long MASK32_3 = mask32(3);
private static final long MASK32_4 = mask32(4);
private static final long MASK32_5 = mask32(5);
private static final long MASK32_6 = mask32(6);
private static final long MASK32_7 = mask32(7);
private static final long MASK32_8 = mask32(8);
private static final long MASK32_9 = mask32(9);
private static final long MASK32_10 = mask32(10);
private static final long MASK32_11 = mask32(11);
private static final long MASK32_12 = mask32(12);
private static final long MASK32_13 = mask32(13);
private static final long MASK32_14 = mask32(14);
private static final long MASK32_15 = mask32(15);
private static final long MASK32_17 = mask32(17);
private static final long MASK32_18 = mask32(18);
private static final long MASK32_19 = mask32(19);
private static final long MASK32_20 = mask32(20);
private static final long MASK32_21 = mask32(21);
private static final long MASK32_22 = mask32(22);
private static final long MASK32_23 = mask32(23);
private static final long MASK32_24 = mask32(24);
private static final long[] MASKS8 = new long[8];
private static final long[] MASKS16 = new long[16];
private static final long[] MASKS32 = new long[32];
static {
for (int i = 0; i < 8; ++i) {
MASKS8[i] = mask8(i);
}
for (int i = 0; i < 16; ++i) {
MASKS16[i] = mask16(i);
}
for (int i = 0; i < 32; ++i) {
MASKS32[i] = mask32(i);
}
}
//mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable
private static final long MASK8_1 = MASKS8[1];
private static final long MASK8_2 = MASKS8[2];
private static final long MASK8_3 = MASKS8[3];
private static final long MASK8_4 = MASKS8[4];
private static final long MASK8_5 = MASKS8[5];
private static final long MASK8_6 = MASKS8[6];
private static final long MASK8_7 = MASKS8[7];
private static final long MASK16_1 = MASKS16[1];
private static final long MASK16_2 = MASKS16[2];
private static final long MASK16_3 = MASKS16[3];
private static final long MASK16_4 = MASKS16[4];
private static final long MASK16_5 = MASKS16[5];
private static final long MASK16_6 = MASKS16[6];
private static final long MASK16_7 = MASKS16[7];
private static final long MASK16_9 = MASKS16[9];
private static final long MASK16_10 = MASKS16[10];
private static final long MASK16_11 = MASKS16[11];
private static final long MASK16_12 = MASKS16[12];
private static final long MASK16_13 = MASKS16[13];
private static final long MASK16_14 = MASKS16[14];
private static final long MASK16_15 = MASKS16[15];
private static final long MASK32_1 = MASKS32[1];
private static final long MASK32_2 = MASKS32[2];
private static final long MASK32_3 = MASKS32[3];
private static final long MASK32_4 = MASKS32[4];
private static final long MASK32_5 = MASKS32[5];
private static final long MASK32_6 = MASKS32[6];
private static final long MASK32_7 = MASKS32[7];
private static final long MASK32_8 = MASKS32[8];
private static final long MASK32_9 = MASKS32[9];
private static final long MASK32_10 = MASKS32[10];
private static final long MASK32_11 = MASKS32[11];
private static final long MASK32_12 = MASKS32[12];
private static final long MASK32_13 = MASKS32[13];
private static final long MASK32_14 = MASKS32[14];
private static final long MASK32_15 = MASKS32[15];
private static final long MASK32_17 = MASKS32[17];
private static final long MASK32_18 = MASKS32[18];
private static final long MASK32_19 = MASKS32[19];
private static final long MASK32_20 = MASKS32[20];
private static final long MASK32_21 = MASKS32[21];
private static final long MASK32_22 = MASKS32[22];
private static final long MASK32_23 = MASKS32[23];
private static final long MASK32_24 = MASKS32[24];
/**

View File

@ -21,6 +21,7 @@ from fractions import gcd
MAX_SPECIALIZED_BITS_PER_VALUE = 24
OUTPUT_FILE = "ForUtil.java"
PRIMITIVE_SIZE = [8, 16, 32]
HEADER = """// This file has been automatically generated, DO NOT EDIT
/*
@ -273,17 +274,17 @@ final class ForUtil {
final int remainingBitsPerLong = shift + bitsPerValue;
final long maskRemainingBitsPerLong;
if (nextPrimitive == 8) {
maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
} else if (nextPrimitive == 16) {
maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
} else {
maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
}
int tmpIdx = 0;
int remainingBitsPerValue = bitsPerValue;
while (idx < numLongs) {
if (remainingBitsPerValue > remainingBitsPerLong) {
if (remainingBitsPerValue >= remainingBitsPerLong) {
remainingBitsPerValue -= remainingBitsPerLong;
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
if (remainingBitsPerValue == 0) {
@ -293,14 +294,14 @@ final class ForUtil {
} else {
final long mask1, mask2;
if (nextPrimitive == 8) {
mask1 = mask8(remainingBitsPerValue);
mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
mask1 = MASKS8[remainingBitsPerValue];
mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
} else if (nextPrimitive == 16) {
mask1 = mask16(remainingBitsPerValue);
mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
mask1 = MASKS16[remainingBitsPerValue];
mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
} else {
mask1 = mask32(remainingBitsPerValue);
mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
mask1 = MASKS32[remainingBitsPerValue];
mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
}
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@ -325,7 +326,7 @@ final class ForUtil {
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
final int numLongs = bitsPerValue << 1;
in.readLELongs(tmp, 0, numLongs);
final long mask = mask32(bitsPerValue);
final long mask = MASKS32[bitsPerValue];
int longsIdx = 0;
int shift = 32 - bitsPerValue;
for (; shift >= 0; shift -= bitsPerValue) {
@ -333,18 +334,18 @@ final class ForUtil {
longsIdx += numLongs;
}
final int remainingBitsPerLong = shift + bitsPerValue;
final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
int tmpIdx = 0;
int remainingBits = remainingBitsPerLong;
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
int b = bitsPerValue - remainingBits;
long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
while (b >= remainingBitsPerLong) {
b -= remainingBitsPerLong;
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
}
if (b > 0) {
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
remainingBits = remainingBitsPerLong - b;
} else {
remainingBits = remainingBitsPerLong;
@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long,
num_values /= 2
iteration *= 2
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
tmp_idx = 0
@ -450,11 +452,21 @@ def writeDecode(bpv, f):
if __name__ == '__main__':
f = open(OUTPUT_FILE, 'w')
f.write(HEADER)
for primitive_size in [8, 16, 32]:
for primitive_size in PRIMITIVE_SIZE:
f.write(' private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size))
f.write(' static {\n')
for primitive_size in PRIMITIVE_SIZE:
f.write(' for (int i = 0; i < %d; ++i) {\n' %primitive_size)
f.write(' MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size))
f.write(' }\n')
f.write(' }\n')
f.write(' //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n')
for primitive_size in PRIMITIVE_SIZE:
for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
if bpv * 2 != primitive_size or primitive_size == 8:
f.write(' private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv))
f.write(' private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv))
f.write('\n')
f.write("""
/**
* Decode 128 integers into {@code longs}.