mirror of https://github.com/apache/lucene.git
LUCENE-9629: Use computed masks (#2113)
Co-authored-by: 郭峰 <guofeng.my@bytedance.com>
This commit is contained in:
parent
894b6b5c88
commit
00a5637457
|
@ -250,17 +250,17 @@ final class ForUtil {
|
||||||
final int remainingBitsPerLong = shift + bitsPerValue;
|
final int remainingBitsPerLong = shift + bitsPerValue;
|
||||||
final long maskRemainingBitsPerLong;
|
final long maskRemainingBitsPerLong;
|
||||||
if (nextPrimitive == 8) {
|
if (nextPrimitive == 8) {
|
||||||
maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
|
maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
|
||||||
} else if (nextPrimitive == 16) {
|
} else if (nextPrimitive == 16) {
|
||||||
maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
|
maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
|
||||||
} else {
|
} else {
|
||||||
maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
|
maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
|
||||||
}
|
}
|
||||||
|
|
||||||
int tmpIdx = 0;
|
int tmpIdx = 0;
|
||||||
int remainingBitsPerValue = bitsPerValue;
|
int remainingBitsPerValue = bitsPerValue;
|
||||||
while (idx < numLongs) {
|
while (idx < numLongs) {
|
||||||
if (remainingBitsPerValue > remainingBitsPerLong) {
|
if (remainingBitsPerValue >= remainingBitsPerLong) {
|
||||||
remainingBitsPerValue -= remainingBitsPerLong;
|
remainingBitsPerValue -= remainingBitsPerLong;
|
||||||
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
|
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
|
||||||
if (remainingBitsPerValue == 0) {
|
if (remainingBitsPerValue == 0) {
|
||||||
|
@ -270,14 +270,14 @@ final class ForUtil {
|
||||||
} else {
|
} else {
|
||||||
final long mask1, mask2;
|
final long mask1, mask2;
|
||||||
if (nextPrimitive == 8) {
|
if (nextPrimitive == 8) {
|
||||||
mask1 = mask8(remainingBitsPerValue);
|
mask1 = MASKS8[remainingBitsPerValue];
|
||||||
mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
|
mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
|
||||||
} else if (nextPrimitive == 16) {
|
} else if (nextPrimitive == 16) {
|
||||||
mask1 = mask16(remainingBitsPerValue);
|
mask1 = MASKS16[remainingBitsPerValue];
|
||||||
mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
|
mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
|
||||||
} else {
|
} else {
|
||||||
mask1 = mask32(remainingBitsPerValue);
|
mask1 = MASKS32[remainingBitsPerValue];
|
||||||
mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
|
mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
|
||||||
}
|
}
|
||||||
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
|
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
|
||||||
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
|
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
|
||||||
|
@ -302,7 +302,7 @@ final class ForUtil {
|
||||||
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
|
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||||
final int numLongs = bitsPerValue << 1;
|
final int numLongs = bitsPerValue << 1;
|
||||||
in.readLELongs(tmp, 0, numLongs);
|
in.readLELongs(tmp, 0, numLongs);
|
||||||
final long mask = mask32(bitsPerValue);
|
final long mask = MASKS32[bitsPerValue];
|
||||||
int longsIdx = 0;
|
int longsIdx = 0;
|
||||||
int shift = 32 - bitsPerValue;
|
int shift = 32 - bitsPerValue;
|
||||||
for (; shift >= 0; shift -= bitsPerValue) {
|
for (; shift >= 0; shift -= bitsPerValue) {
|
||||||
|
@ -310,18 +310,18 @@ final class ForUtil {
|
||||||
longsIdx += numLongs;
|
longsIdx += numLongs;
|
||||||
}
|
}
|
||||||
final int remainingBitsPerLong = shift + bitsPerValue;
|
final int remainingBitsPerLong = shift + bitsPerValue;
|
||||||
final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
|
final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
|
||||||
int tmpIdx = 0;
|
int tmpIdx = 0;
|
||||||
int remainingBits = remainingBitsPerLong;
|
int remainingBits = remainingBitsPerLong;
|
||||||
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
|
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
|
||||||
int b = bitsPerValue - remainingBits;
|
int b = bitsPerValue - remainingBits;
|
||||||
long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
|
long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
|
||||||
while (b >= remainingBitsPerLong) {
|
while (b >= remainingBitsPerLong) {
|
||||||
b -= remainingBitsPerLong;
|
b -= remainingBitsPerLong;
|
||||||
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
|
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
|
||||||
}
|
}
|
||||||
if (b > 0) {
|
if (b > 0) {
|
||||||
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
|
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
|
||||||
remainingBits = remainingBitsPerLong - b;
|
remainingBits = remainingBitsPerLong - b;
|
||||||
} else {
|
} else {
|
||||||
remainingBits = remainingBitsPerLong;
|
remainingBits = remainingBitsPerLong;
|
||||||
|
@ -341,50 +341,65 @@ final class ForUtil {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final long MASK8_1 = mask8(1);
|
private static final long[] MASKS8 = new long[8];
|
||||||
private static final long MASK8_2 = mask8(2);
|
private static final long[] MASKS16 = new long[16];
|
||||||
private static final long MASK8_3 = mask8(3);
|
private static final long[] MASKS32 = new long[32];
|
||||||
private static final long MASK8_4 = mask8(4);
|
static {
|
||||||
private static final long MASK8_5 = mask8(5);
|
for (int i = 0; i < 8; ++i) {
|
||||||
private static final long MASK8_6 = mask8(6);
|
MASKS8[i] = mask8(i);
|
||||||
private static final long MASK8_7 = mask8(7);
|
}
|
||||||
private static final long MASK16_1 = mask16(1);
|
for (int i = 0; i < 16; ++i) {
|
||||||
private static final long MASK16_2 = mask16(2);
|
MASKS16[i] = mask16(i);
|
||||||
private static final long MASK16_3 = mask16(3);
|
}
|
||||||
private static final long MASK16_4 = mask16(4);
|
for (int i = 0; i < 32; ++i) {
|
||||||
private static final long MASK16_5 = mask16(5);
|
MASKS32[i] = mask32(i);
|
||||||
private static final long MASK16_6 = mask16(6);
|
}
|
||||||
private static final long MASK16_7 = mask16(7);
|
}
|
||||||
private static final long MASK16_9 = mask16(9);
|
//mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable
|
||||||
private static final long MASK16_10 = mask16(10);
|
private static final long MASK8_1 = MASKS8[1];
|
||||||
private static final long MASK16_11 = mask16(11);
|
private static final long MASK8_2 = MASKS8[2];
|
||||||
private static final long MASK16_12 = mask16(12);
|
private static final long MASK8_3 = MASKS8[3];
|
||||||
private static final long MASK16_13 = mask16(13);
|
private static final long MASK8_4 = MASKS8[4];
|
||||||
private static final long MASK16_14 = mask16(14);
|
private static final long MASK8_5 = MASKS8[5];
|
||||||
private static final long MASK16_15 = mask16(15);
|
private static final long MASK8_6 = MASKS8[6];
|
||||||
private static final long MASK32_1 = mask32(1);
|
private static final long MASK8_7 = MASKS8[7];
|
||||||
private static final long MASK32_2 = mask32(2);
|
private static final long MASK16_1 = MASKS16[1];
|
||||||
private static final long MASK32_3 = mask32(3);
|
private static final long MASK16_2 = MASKS16[2];
|
||||||
private static final long MASK32_4 = mask32(4);
|
private static final long MASK16_3 = MASKS16[3];
|
||||||
private static final long MASK32_5 = mask32(5);
|
private static final long MASK16_4 = MASKS16[4];
|
||||||
private static final long MASK32_6 = mask32(6);
|
private static final long MASK16_5 = MASKS16[5];
|
||||||
private static final long MASK32_7 = mask32(7);
|
private static final long MASK16_6 = MASKS16[6];
|
||||||
private static final long MASK32_8 = mask32(8);
|
private static final long MASK16_7 = MASKS16[7];
|
||||||
private static final long MASK32_9 = mask32(9);
|
private static final long MASK16_9 = MASKS16[9];
|
||||||
private static final long MASK32_10 = mask32(10);
|
private static final long MASK16_10 = MASKS16[10];
|
||||||
private static final long MASK32_11 = mask32(11);
|
private static final long MASK16_11 = MASKS16[11];
|
||||||
private static final long MASK32_12 = mask32(12);
|
private static final long MASK16_12 = MASKS16[12];
|
||||||
private static final long MASK32_13 = mask32(13);
|
private static final long MASK16_13 = MASKS16[13];
|
||||||
private static final long MASK32_14 = mask32(14);
|
private static final long MASK16_14 = MASKS16[14];
|
||||||
private static final long MASK32_15 = mask32(15);
|
private static final long MASK16_15 = MASKS16[15];
|
||||||
private static final long MASK32_17 = mask32(17);
|
private static final long MASK32_1 = MASKS32[1];
|
||||||
private static final long MASK32_18 = mask32(18);
|
private static final long MASK32_2 = MASKS32[2];
|
||||||
private static final long MASK32_19 = mask32(19);
|
private static final long MASK32_3 = MASKS32[3];
|
||||||
private static final long MASK32_20 = mask32(20);
|
private static final long MASK32_4 = MASKS32[4];
|
||||||
private static final long MASK32_21 = mask32(21);
|
private static final long MASK32_5 = MASKS32[5];
|
||||||
private static final long MASK32_22 = mask32(22);
|
private static final long MASK32_6 = MASKS32[6];
|
||||||
private static final long MASK32_23 = mask32(23);
|
private static final long MASK32_7 = MASKS32[7];
|
||||||
private static final long MASK32_24 = mask32(24);
|
private static final long MASK32_8 = MASKS32[8];
|
||||||
|
private static final long MASK32_9 = MASKS32[9];
|
||||||
|
private static final long MASK32_10 = MASKS32[10];
|
||||||
|
private static final long MASK32_11 = MASKS32[11];
|
||||||
|
private static final long MASK32_12 = MASKS32[12];
|
||||||
|
private static final long MASK32_13 = MASKS32[13];
|
||||||
|
private static final long MASK32_14 = MASKS32[14];
|
||||||
|
private static final long MASK32_15 = MASKS32[15];
|
||||||
|
private static final long MASK32_17 = MASKS32[17];
|
||||||
|
private static final long MASK32_18 = MASKS32[18];
|
||||||
|
private static final long MASK32_19 = MASKS32[19];
|
||||||
|
private static final long MASK32_20 = MASKS32[20];
|
||||||
|
private static final long MASK32_21 = MASKS32[21];
|
||||||
|
private static final long MASK32_22 = MASKS32[22];
|
||||||
|
private static final long MASK32_23 = MASKS32[23];
|
||||||
|
private static final long MASK32_24 = MASKS32[24];
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -21,6 +21,7 @@ from fractions import gcd
|
||||||
|
|
||||||
MAX_SPECIALIZED_BITS_PER_VALUE = 24
|
MAX_SPECIALIZED_BITS_PER_VALUE = 24
|
||||||
OUTPUT_FILE = "ForUtil.java"
|
OUTPUT_FILE = "ForUtil.java"
|
||||||
|
PRIMITIVE_SIZE = [8, 16, 32]
|
||||||
HEADER = """// This file has been automatically generated, DO NOT EDIT
|
HEADER = """// This file has been automatically generated, DO NOT EDIT
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -273,17 +274,17 @@ final class ForUtil {
|
||||||
final int remainingBitsPerLong = shift + bitsPerValue;
|
final int remainingBitsPerLong = shift + bitsPerValue;
|
||||||
final long maskRemainingBitsPerLong;
|
final long maskRemainingBitsPerLong;
|
||||||
if (nextPrimitive == 8) {
|
if (nextPrimitive == 8) {
|
||||||
maskRemainingBitsPerLong = mask8(remainingBitsPerLong);
|
maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong];
|
||||||
} else if (nextPrimitive == 16) {
|
} else if (nextPrimitive == 16) {
|
||||||
maskRemainingBitsPerLong = mask16(remainingBitsPerLong);
|
maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong];
|
||||||
} else {
|
} else {
|
||||||
maskRemainingBitsPerLong = mask32(remainingBitsPerLong);
|
maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong];
|
||||||
}
|
}
|
||||||
|
|
||||||
int tmpIdx = 0;
|
int tmpIdx = 0;
|
||||||
int remainingBitsPerValue = bitsPerValue;
|
int remainingBitsPerValue = bitsPerValue;
|
||||||
while (idx < numLongs) {
|
while (idx < numLongs) {
|
||||||
if (remainingBitsPerValue > remainingBitsPerLong) {
|
if (remainingBitsPerValue >= remainingBitsPerLong) {
|
||||||
remainingBitsPerValue -= remainingBitsPerLong;
|
remainingBitsPerValue -= remainingBitsPerLong;
|
||||||
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
|
tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong;
|
||||||
if (remainingBitsPerValue == 0) {
|
if (remainingBitsPerValue == 0) {
|
||||||
|
@ -293,14 +294,14 @@ final class ForUtil {
|
||||||
} else {
|
} else {
|
||||||
final long mask1, mask2;
|
final long mask1, mask2;
|
||||||
if (nextPrimitive == 8) {
|
if (nextPrimitive == 8) {
|
||||||
mask1 = mask8(remainingBitsPerValue);
|
mask1 = MASKS8[remainingBitsPerValue];
|
||||||
mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue);
|
mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue];
|
||||||
} else if (nextPrimitive == 16) {
|
} else if (nextPrimitive == 16) {
|
||||||
mask1 = mask16(remainingBitsPerValue);
|
mask1 = MASKS16[remainingBitsPerValue];
|
||||||
mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue);
|
mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue];
|
||||||
} else {
|
} else {
|
||||||
mask1 = mask32(remainingBitsPerValue);
|
mask1 = MASKS32[remainingBitsPerValue];
|
||||||
mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue);
|
mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue];
|
||||||
}
|
}
|
||||||
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
|
tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue);
|
||||||
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
|
remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue;
|
||||||
|
@ -325,7 +326,7 @@ final class ForUtil {
|
||||||
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
|
private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException {
|
||||||
final int numLongs = bitsPerValue << 1;
|
final int numLongs = bitsPerValue << 1;
|
||||||
in.readLELongs(tmp, 0, numLongs);
|
in.readLELongs(tmp, 0, numLongs);
|
||||||
final long mask = mask32(bitsPerValue);
|
final long mask = MASKS32[bitsPerValue];
|
||||||
int longsIdx = 0;
|
int longsIdx = 0;
|
||||||
int shift = 32 - bitsPerValue;
|
int shift = 32 - bitsPerValue;
|
||||||
for (; shift >= 0; shift -= bitsPerValue) {
|
for (; shift >= 0; shift -= bitsPerValue) {
|
||||||
|
@ -333,18 +334,18 @@ final class ForUtil {
|
||||||
longsIdx += numLongs;
|
longsIdx += numLongs;
|
||||||
}
|
}
|
||||||
final int remainingBitsPerLong = shift + bitsPerValue;
|
final int remainingBitsPerLong = shift + bitsPerValue;
|
||||||
final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong);
|
final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong];
|
||||||
int tmpIdx = 0;
|
int tmpIdx = 0;
|
||||||
int remainingBits = remainingBitsPerLong;
|
int remainingBits = remainingBitsPerLong;
|
||||||
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
|
for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) {
|
||||||
int b = bitsPerValue - remainingBits;
|
int b = bitsPerValue - remainingBits;
|
||||||
long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b;
|
long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b;
|
||||||
while (b >= remainingBitsPerLong) {
|
while (b >= remainingBitsPerLong) {
|
||||||
b -= remainingBitsPerLong;
|
b -= remainingBitsPerLong;
|
||||||
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
|
l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b;
|
||||||
}
|
}
|
||||||
if (b > 0) {
|
if (b > 0) {
|
||||||
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b);
|
l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b];
|
||||||
remainingBits = remainingBitsPerLong - b;
|
remainingBits = remainingBitsPerLong - b;
|
||||||
} else {
|
} else {
|
||||||
remainingBits = remainingBitsPerLong;
|
remainingBits = remainingBitsPerLong;
|
||||||
|
@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long,
|
||||||
num_values /= 2
|
num_values /= 2
|
||||||
iteration *= 2
|
iteration *= 2
|
||||||
|
|
||||||
|
|
||||||
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
|
f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long))
|
||||||
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
|
f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values))
|
||||||
tmp_idx = 0
|
tmp_idx = 0
|
||||||
|
@ -450,11 +452,21 @@ def writeDecode(bpv, f):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
f = open(OUTPUT_FILE, 'w')
|
f = open(OUTPUT_FILE, 'w')
|
||||||
f.write(HEADER)
|
f.write(HEADER)
|
||||||
for primitive_size in [8, 16, 32]:
|
for primitive_size in PRIMITIVE_SIZE:
|
||||||
|
f.write(' private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size))
|
||||||
|
f.write(' static {\n')
|
||||||
|
for primitive_size in PRIMITIVE_SIZE:
|
||||||
|
f.write(' for (int i = 0; i < %d; ++i) {\n' %primitive_size)
|
||||||
|
f.write(' MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size))
|
||||||
|
f.write(' }\n')
|
||||||
|
f.write(' }\n')
|
||||||
|
f.write(' //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n')
|
||||||
|
for primitive_size in PRIMITIVE_SIZE:
|
||||||
for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
|
for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)):
|
||||||
if bpv * 2 != primitive_size or primitive_size == 8:
|
if bpv * 2 != primitive_size or primitive_size == 8:
|
||||||
f.write(' private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv))
|
f.write(' private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv))
|
||||||
f.write('\n')
|
f.write('\n')
|
||||||
|
|
||||||
f.write("""
|
f.write("""
|
||||||
/**
|
/**
|
||||||
* Decode 128 integers into {@code longs}.
|
* Decode 128 integers into {@code longs}.
|
||||||
|
|
Loading…
Reference in New Issue