LUCENE-9629: Use computed masks (#2113)

Co-authored-by: 郭峰 <guofeng.my@bytedance.com>
This commit is contained in:
gf2121 2020-12-18 08:59:40 -06:00 committed by GitHub
parent 894b6b5c88
commit 00a5637457
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 101 additions and 74 deletions

View File

@ -250,17 +250,17 @@ final class ForUtil {
final int remainingBitsPerLong = shift + bitsPerValue; final int remainingBitsPerLong = shift + bitsPerValue;
final long maskRemainingBitsPerLong; final long maskRemainingBitsPerLong;
if (nextPrimitive == 8) { if (nextPrimitive == 8) {
maskRemainingBitsPerLong = mask8(remainingBitsPerLong); maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
} else if (nextPrimitive == 16) { } else if (nextPrimitive == 16) {
maskRemainingBitsPerLong = mask16(remainingBitsPerLong); maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
} else { } else {
maskRemainingBitsPerLong = mask32(remainingBitsPerLong); maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
} }
int tmpIdx = 0; int tmpIdx = 0;
int remainingBitsPerValue = bitsPerValue; int remainingBitsPerValue = bitsPerValue;
while (idx < numLongs) { while (idx < numLongs) {
if (remainingBitsPerValue > remainingBitsPerLong) { if (remainingBitsPerValue >= remainingBitsPerLong) {
remainingBitsPerValue -= remainingBitsPerLong; remainingBitsPerValue -= remainingBitsPerLong;
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong; tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
if (remainingBitsPerValue == 0) { if (remainingBitsPerValue == 0) {
@ -270,14 +270,14 @@ final class ForUtil {
} else { } else {
final long mask1, mask2; final long mask1, mask2;
if (nextPrimitive == 8) { if (nextPrimitive == 8) {
mask1 = mask8(remainingBitsPerValue); mask1 = MASKS8[remainingBitsPerValue];
mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue); mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
} else if (nextPrimitive == 16) { } else if (nextPrimitive == 16) {
mask1 = mask16(remainingBitsPerValue); mask1 = MASKS16[remainingBitsPerValue];
mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue); mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
} else { } else {
mask1 = mask32(remainingBitsPerValue); mask1 = MASKS32[remainingBitsPerValue];
mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue); mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
} }
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue); tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue; remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@ -302,7 +302,7 @@ final class ForUtil {
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException { private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
final int numLongs = bitsPerValue << 1; final int numLongs = bitsPerValue << 1;
in.readLELongs(tmp, 0, numLongs); in.readLELongs(tmp, 0, numLongs);
final long mask = mask32(bitsPerValue); final long mask = MASKS32[bitsPerValue];
int longsIdx = 0; int longsIdx = 0;
int shift = 32 - bitsPerValue; int shift = 32 - bitsPerValue;
for (; shift >= 0; shift -= bitsPerValue) { for (; shift >= 0; shift -= bitsPerValue) {
@ -310,18 +310,18 @@ final class ForUtil {
longsIdx += numLongs; longsIdx += numLongs;
} }
final int remainingBitsPerLong = shift + bitsPerValue; final int remainingBitsPerLong = shift + bitsPerValue;
final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong); final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
int tmpIdx = 0; int tmpIdx = 0;
int remainingBits = remainingBitsPerLong; int remainingBits = remainingBitsPerLong;
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) { for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
int b = bitsPerValue - remainingBits; int b = bitsPerValue - remainingBits;
long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b; long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
while (b >= remainingBitsPerLong) { while (b >= remainingBitsPerLong) {
b -= remainingBitsPerLong; b -= remainingBitsPerLong;
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b; l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
} }
if (b > 0) { if (b > 0) {
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b); l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
remainingBits = remainingBitsPerLong - b; remainingBits = remainingBitsPerLong - b;
} else { } else {
remainingBits = remainingBitsPerLong; remainingBits = remainingBitsPerLong;
@ -341,50 +341,65 @@ final class ForUtil {
} }
} }
private static final long MASK8_1 = mask8(1); private static final long[] MASKS8 = new long[8];
private static final long MASK8_2 = mask8(2); private static final long[] MASKS16 = new long[16];
private static final long MASK8_3 = mask8(3); private static final long[] MASKS32 = new long[32];
private static final long MASK8_4 = mask8(4); static {
private static final long MASK8_5 = mask8(5); for (int i = 0; i < 8; ++i) {
private static final long MASK8_6 = mask8(6); MASKS8[i] = mask8(i);
private static final long MASK8_7 = mask8(7); }
private static final long MASK16_1 = mask16(1); for (int i = 0; i < 16; ++i) {
private static final long MASK16_2 = mask16(2); MASKS16[i] = mask16(i);
private static final long MASK16_3 = mask16(3); }
private static final long MASK16_4 = mask16(4); for (int i = 0; i < 32; ++i) {
private static final long MASK16_5 = mask16(5); MASKS32[i] = mask32(i);
private static final long MASK16_6 = mask16(6); }
private static final long MASK16_7 = mask16(7); }
private static final long MASK16_9 = mask16(9); //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable
private static final long MASK16_10 = mask16(10); private static final long MASK8_1 = MASKS8[1];
private static final long MASK16_11 = mask16(11); private static final long MASK8_2 = MASKS8[2];
private static final long MASK16_12 = mask16(12); private static final long MASK8_3 = MASKS8[3];
private static final long MASK16_13 = mask16(13); private static final long MASK8_4 = MASKS8[4];
private static final long MASK16_14 = mask16(14); private static final long MASK8_5 = MASKS8[5];
private static final long MASK16_15 = mask16(15); private static final long MASK8_6 = MASKS8[6];
private static final long MASK32_1 = mask32(1); private static final long MASK8_7 = MASKS8[7];
private static final long MASK32_2 = mask32(2); private static final long MASK16_1 = MASKS16[1];
private static final long MASK32_3 = mask32(3); private static final long MASK16_2 = MASKS16[2];
private static final long MASK32_4 = mask32(4); private static final long MASK16_3 = MASKS16[3];
private static final long MASK32_5 = mask32(5); private static final long MASK16_4 = MASKS16[4];
private static final long MASK32_6 = mask32(6); private static final long MASK16_5 = MASKS16[5];
private static final long MASK32_7 = mask32(7); private static final long MASK16_6 = MASKS16[6];
private static final long MASK32_8 = mask32(8); private static final long MASK16_7 = MASKS16[7];
private static final long MASK32_9 = mask32(9); private static final long MASK16_9 = MASKS16[9];
private static final long MASK32_10 = mask32(10); private static final long MASK16_10 = MASKS16[10];
private static final long MASK32_11 = mask32(11); private static final long MASK16_11 = MASKS16[11];
private static final long MASK32_12 = mask32(12); private static final long MASK16_12 = MASKS16[12];
private static final long MASK32_13 = mask32(13); private static final long MASK16_13 = MASKS16[13];
private static final long MASK32_14 = mask32(14); private static final long MASK16_14 = MASKS16[14];
private static final long MASK32_15 = mask32(15); private static final long MASK16_15 = MASKS16[15];
private static final long MASK32_17 = mask32(17); private static final long MASK32_1 = MASKS32[1];
private static final long MASK32_18 = mask32(18); private static final long MASK32_2 = MASKS32[2];
private static final long MASK32_19 = mask32(19); private static final long MASK32_3 = MASKS32[3];
private static final long MASK32_20 = mask32(20); private static final long MASK32_4 = MASKS32[4];
private static final long MASK32_21 = mask32(21); private static final long MASK32_5 = MASKS32[5];
private static final long MASK32_22 = mask32(22); private static final long MASK32_6 = MASKS32[6];
private static final long MASK32_23 = mask32(23); private static final long MASK32_7 = MASKS32[7];
private static final long MASK32_24 = mask32(24); private static final long MASK32_8 = MASKS32[8];
private static final long MASK32_9 = MASKS32[9];
private static final long MASK32_10 = MASKS32[10];
private static final long MASK32_11 = MASKS32[11];
private static final long MASK32_12 = MASKS32[12];
private static final long MASK32_13 = MASKS32[13];
private static final long MASK32_14 = MASKS32[14];
private static final long MASK32_15 = MASKS32[15];
private static final long MASK32_17 = MASKS32[17];
private static final long MASK32_18 = MASKS32[18];
private static final long MASK32_19 = MASKS32[19];
private static final long MASK32_20 = MASKS32[20];
private static final long MASK32_21 = MASKS32[21];
private static final long MASK32_22 = MASKS32[22];
private static final long MASK32_23 = MASKS32[23];
private static final long MASK32_24 = MASKS32[24];
/** /**

View File

@ -21,6 +21,7 @@ from fractions import gcd
MAX_SPECIALIZED_BITS_PER_VALUE = 24 MAX_SPECIALIZED_BITS_PER_VALUE = 24
OUTPUT_FILE = "ForUtil.java" OUTPUT_FILE = "ForUtil.java"
PRIMITIVE_SIZE = [8, 16, 32]
HEADER = """// This file has been automatically generated, DO NOT EDIT HEADER = """// This file has been automatically generated, DO NOT EDIT
/* /*
@ -273,17 +274,17 @@ final class ForUtil {
final int remainingBitsPerLong = shift + bitsPerValue; final int remainingBitsPerLong = shift + bitsPerValue;
final long maskRemainingBitsPerLong; final long maskRemainingBitsPerLong;
if (nextPrimitive == 8) { if (nextPrimitive == 8) {
maskRemainingBitsPerLong = mask8(remainingBitsPerLong); maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
} else if (nextPrimitive == 16) { } else if (nextPrimitive == 16) {
maskRemainingBitsPerLong = mask16(remainingBitsPerLong); maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
} else { } else {
maskRemainingBitsPerLong = mask32(remainingBitsPerLong); maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
} }
int tmpIdx = 0; int tmpIdx = 0;
int remainingBitsPerValue = bitsPerValue; int remainingBitsPerValue = bitsPerValue;
while (idx < numLongs) { while (idx < numLongs) {
if (remainingBitsPerValue > remainingBitsPerLong) { if (remainingBitsPerValue >= remainingBitsPerLong) {
remainingBitsPerValue -= remainingBitsPerLong; remainingBitsPerValue -= remainingBitsPerLong;
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong; tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
if (remainingBitsPerValue == 0) { if (remainingBitsPerValue == 0) {
@ -293,14 +294,14 @@ final class ForUtil {
} else { } else {
final long mask1, mask2; final long mask1, mask2;
if (nextPrimitive == 8) { if (nextPrimitive == 8) {
mask1 = mask8(remainingBitsPerValue); mask1 = MASKS8[remainingBitsPerValue];
mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue); mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
} else if (nextPrimitive == 16) { } else if (nextPrimitive == 16) {
mask1 = mask16(remainingBitsPerValue); mask1 = MASKS16[remainingBitsPerValue];
mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue); mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
} else { } else {
mask1 = mask32(remainingBitsPerValue); mask1 = MASKS32[remainingBitsPerValue];
mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue); mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
} }
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue); tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue; remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
@ -325,7 +326,7 @@ final class ForUtil {
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException { private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
final int numLongs = bitsPerValue << 1; final int numLongs = bitsPerValue << 1;
in.readLELongs(tmp, 0, numLongs); in.readLELongs(tmp, 0, numLongs);
final long mask = mask32(bitsPerValue); final long mask = MASKS32[bitsPerValue];
int longsIdx = 0; int longsIdx = 0;
int shift = 32 - bitsPerValue; int shift = 32 - bitsPerValue;
for (; shift >= 0; shift -= bitsPerValue) { for (; shift >= 0; shift -= bitsPerValue) {
@ -333,18 +334,18 @@ final class ForUtil {
longsIdx += numLongs; longsIdx += numLongs;
} }
final int remainingBitsPerLong = shift + bitsPerValue; final int remainingBitsPerLong = shift + bitsPerValue;
final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong); final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
int tmpIdx = 0; int tmpIdx = 0;
int remainingBits = remainingBitsPerLong; int remainingBits = remainingBitsPerLong;
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) { for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
int b = bitsPerValue - remainingBits; int b = bitsPerValue - remainingBits;
long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b; long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
while (b >= remainingBitsPerLong) { while (b >= remainingBitsPerLong) {
b -= remainingBitsPerLong; b -= remainingBitsPerLong;
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b; l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
} }
if (b > 0) { if (b > 0) {
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b); l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
remainingBits = remainingBitsPerLong - b; remainingBits = remainingBitsPerLong - b;
} else { } else {
remainingBits = remainingBitsPerLong; remainingBits = remainingBitsPerLong;
@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long,
num_values /= 2 num_values /= 2
iteration *= 2 iteration *= 2
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long)) f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values)) f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
tmp_idx = 0 tmp_idx = 0
@ -450,11 +452,21 @@ def writeDecode(bpv, f):
if __name__ == '__main__': if __name__ == '__main__':
f = open(OUTPUT_FILE, 'w') f = open(OUTPUT_FILE, 'w')
f.write(HEADER) f.write(HEADER)
for primitive_size in [8, 16, 32]: for primitive_size in PRIMITIVE_SIZE:
f.write(' private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size))
f.write(' static {\n')
for primitive_size in PRIMITIVE_SIZE:
f.write(' for (int i = 0; i < %d; ++i) {\n' %primitive_size)
f.write(' MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size))
f.write(' }\n')
f.write(' }\n')
f.write(' //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n')
for primitive_size in PRIMITIVE_SIZE:
for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)): for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
if bpv * 2 != primitive_size or primitive_size == 8: if bpv * 2 != primitive_size or primitive_size == 8:
f.write(' private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv)) f.write(' private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv))
f.write('\n') f.write('\n')
f.write(""" f.write("""
/** /**
* Decode 128 integers into {@code longs}. * Decode 128 integers into {@code longs}.