From 00a56374571c4ab51191659ebe57552e5f245881 Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Fri, 18 Dec 2020 08:59:40 -0600 Subject: [PATCH] LUCENE-9629: Use computed masks (#2113) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 郭峰 --- .../lucene/codecs/lucene84/ForUtil.java | 131 ++++++++++-------- .../lucene/codecs/lucene84/gen_ForUtil.py | 44 +++--- 2 files changed, 101 insertions(+), 74 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java index eb07ec18f89..266624f399b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/ForUtil.java @@ -250,17 +250,17 @@ final class ForUtil { final int remainingBitsPerLong = shift + bitsPerValue; final long maskRemainingBitsPerLong; if (nextPrimitive == 8) { - maskRemainingBitsPerLong = mask8(remainingBitsPerLong); + maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong]; } else if (nextPrimitive == 16) { - maskRemainingBitsPerLong = mask16(remainingBitsPerLong); + maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong]; } else { - maskRemainingBitsPerLong = mask32(remainingBitsPerLong); + maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong]; } int tmpIdx = 0; int remainingBitsPerValue = bitsPerValue; while (idx < numLongs) { - if (remainingBitsPerValue > remainingBitsPerLong) { + if (remainingBitsPerValue >= remainingBitsPerLong) { remainingBitsPerValue -= remainingBitsPerLong; tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong; if (remainingBitsPerValue == 0) { @@ -270,14 +270,14 @@ final class ForUtil { } else { final long mask1, mask2; if (nextPrimitive == 8) { - mask1 = mask8(remainingBitsPerValue); - mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue); + mask1 = MASKS8[remainingBitsPerValue]; + mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue]; } else if (nextPrimitive == 16) { - mask1 = mask16(remainingBitsPerValue); - mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue); + mask1 = MASKS16[remainingBitsPerValue]; + mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue]; } else { - mask1 = mask32(remainingBitsPerValue); - mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue); + mask1 = MASKS32[remainingBitsPerValue]; + mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue]; } tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue); remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue; @@ -302,7 +302,7 @@ final class ForUtil { private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException { final int numLongs = bitsPerValue << 1; in.readLELongs(tmp, 0, numLongs); - final long mask = mask32(bitsPerValue); + final long mask = MASKS32[bitsPerValue]; int longsIdx = 0; int shift = 32 - bitsPerValue; for (; shift >= 0; shift -= bitsPerValue) { @@ -310,18 +310,18 @@ final class ForUtil { longsIdx += numLongs; } final int remainingBitsPerLong = shift + bitsPerValue; - final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong); + final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong]; int tmpIdx = 0; int remainingBits = remainingBitsPerLong; for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) { int b = bitsPerValue - remainingBits; - long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b; + long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b; while (b >= remainingBitsPerLong) { b -= remainingBitsPerLong; l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b; } if (b > 0) { - l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b); + l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b]; remainingBits = remainingBitsPerLong - b; } else { remainingBits = remainingBitsPerLong; @@ -341,50 +341,65 @@ final class ForUtil { } } - private static final long MASK8_1 = mask8(1); - private static final long MASK8_2 = mask8(2); - private static final long MASK8_3 = mask8(3); - private static final long MASK8_4 = mask8(4); - private static final long MASK8_5 = mask8(5); - private static final long MASK8_6 = mask8(6); - private static final long MASK8_7 = mask8(7); - private static final long MASK16_1 = mask16(1); - private static final long MASK16_2 = mask16(2); - private static final long MASK16_3 = mask16(3); - private static final long MASK16_4 = mask16(4); - private static final long MASK16_5 = mask16(5); - private static final long MASK16_6 = mask16(6); - private static final long MASK16_7 = mask16(7); - private static final long MASK16_9 = mask16(9); - private static final long MASK16_10 = mask16(10); - private static final long MASK16_11 = mask16(11); - private static final long MASK16_12 = mask16(12); - private static final long MASK16_13 = mask16(13); - private static final long MASK16_14 = mask16(14); - private static final long MASK16_15 = mask16(15); - private static final long MASK32_1 = mask32(1); - private static final long MASK32_2 = mask32(2); - private static final long MASK32_3 = mask32(3); - private static final long MASK32_4 = mask32(4); - private static final long MASK32_5 = mask32(5); - private static final long MASK32_6 = mask32(6); - private static final long MASK32_7 = mask32(7); - private static final long MASK32_8 = mask32(8); - private static final long MASK32_9 = mask32(9); - private static final long MASK32_10 = mask32(10); - private static final long MASK32_11 = mask32(11); - private static final long MASK32_12 = mask32(12); - private static final long MASK32_13 = mask32(13); - private static final long MASK32_14 = mask32(14); - private static final long MASK32_15 = mask32(15); - private static final long MASK32_17 = mask32(17); - private static final long MASK32_18 = mask32(18); - private static final long MASK32_19 = mask32(19); - private static final long MASK32_20 = mask32(20); - private static final long MASK32_21 = mask32(21); - private static final long MASK32_22 = mask32(22); - private static final long MASK32_23 = mask32(23); - private static final long MASK32_24 = mask32(24); + private static final long[] MASKS8 = new long[8]; + private static final long[] MASKS16 = new long[16]; + private static final long[] MASKS32 = new long[32]; + static { + for (int i = 0; i < 8; ++i) { + MASKS8[i] = mask8(i); + } + for (int i = 0; i < 16; ++i) { + MASKS16[i] = mask16(i); + } + for (int i = 0; i < 32; ++i) { + MASKS32[i] = mask32(i); + } + } + //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable + private static final long MASK8_1 = MASKS8[1]; + private static final long MASK8_2 = MASKS8[2]; + private static final long MASK8_3 = MASKS8[3]; + private static final long MASK8_4 = MASKS8[4]; + private static final long MASK8_5 = MASKS8[5]; + private static final long MASK8_6 = MASKS8[6]; + private static final long MASK8_7 = MASKS8[7]; + private static final long MASK16_1 = MASKS16[1]; + private static final long MASK16_2 = MASKS16[2]; + private static final long MASK16_3 = MASKS16[3]; + private static final long MASK16_4 = MASKS16[4]; + private static final long MASK16_5 = MASKS16[5]; + private static final long MASK16_6 = MASKS16[6]; + private static final long MASK16_7 = MASKS16[7]; + private static final long MASK16_9 = MASKS16[9]; + private static final long MASK16_10 = MASKS16[10]; + private static final long MASK16_11 = MASKS16[11]; + private static final long MASK16_12 = MASKS16[12]; + private static final long MASK16_13 = MASKS16[13]; + private static final long MASK16_14 = MASKS16[14]; + private static final long MASK16_15 = MASKS16[15]; + private static final long MASK32_1 = MASKS32[1]; + private static final long MASK32_2 = MASKS32[2]; + private static final long MASK32_3 = MASKS32[3]; + private static final long MASK32_4 = MASKS32[4]; + private static final long MASK32_5 = MASKS32[5]; + private static final long MASK32_6 = MASKS32[6]; + private static final long MASK32_7 = MASKS32[7]; + private static final long MASK32_8 = MASKS32[8]; + private static final long MASK32_9 = MASKS32[9]; + private static final long MASK32_10 = MASKS32[10]; + private static final long MASK32_11 = MASKS32[11]; + private static final long MASK32_12 = MASKS32[12]; + private static final long MASK32_13 = MASKS32[13]; + private static final long MASK32_14 = MASKS32[14]; + private static final long MASK32_15 = MASKS32[15]; + private static final long MASK32_17 = MASKS32[17]; + private static final long MASK32_18 = MASKS32[18]; + private static final long MASK32_19 = MASKS32[19]; + private static final long MASK32_20 = MASKS32[20]; + private static final long MASK32_21 = MASKS32[21]; + private static final long MASK32_22 = MASKS32[22]; + private static final long MASK32_23 = MASKS32[23]; + private static final long MASK32_24 = MASKS32[24]; /** diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py index 94f31e24a95..30256182122 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene84/gen_ForUtil.py @@ -21,6 +21,7 @@ from fractions import gcd MAX_SPECIALIZED_BITS_PER_VALUE = 24 OUTPUT_FILE = "ForUtil.java" +PRIMITIVE_SIZE = [8, 16, 32] HEADER = """// This file has been automatically generated, DO NOT EDIT /* @@ -273,17 +274,17 @@ final class ForUtil { final int remainingBitsPerLong = shift + bitsPerValue; final long maskRemainingBitsPerLong; if (nextPrimitive == 8) { - maskRemainingBitsPerLong = mask8(remainingBitsPerLong); + maskRemainingBitsPerLong = MASKS8[remainingBitsPerLong]; } else if (nextPrimitive == 16) { - maskRemainingBitsPerLong = mask16(remainingBitsPerLong); + maskRemainingBitsPerLong = MASKS16[remainingBitsPerLong]; } else { - maskRemainingBitsPerLong = mask32(remainingBitsPerLong); + maskRemainingBitsPerLong = MASKS32[remainingBitsPerLong]; } int tmpIdx = 0; int remainingBitsPerValue = bitsPerValue; while (idx < numLongs) { - if (remainingBitsPerValue > remainingBitsPerLong) { + if (remainingBitsPerValue >= remainingBitsPerLong) { remainingBitsPerValue -= remainingBitsPerLong; tmp[tmpIdx++] |= (longs[idx] >>> remainingBitsPerValue) & maskRemainingBitsPerLong; if (remainingBitsPerValue == 0) { @@ -293,14 +294,14 @@ final class ForUtil { } else { final long mask1, mask2; if (nextPrimitive == 8) { - mask1 = mask8(remainingBitsPerValue); - mask2 = mask8(remainingBitsPerLong - remainingBitsPerValue); + mask1 = MASKS8[remainingBitsPerValue]; + mask2 = MASKS8[remainingBitsPerLong - remainingBitsPerValue]; } else if (nextPrimitive == 16) { - mask1 = mask16(remainingBitsPerValue); - mask2 = mask16(remainingBitsPerLong - remainingBitsPerValue); + mask1 = MASKS16[remainingBitsPerValue]; + mask2 = MASKS16[remainingBitsPerLong - remainingBitsPerValue]; } else { - mask1 = mask32(remainingBitsPerValue); - mask2 = mask32(remainingBitsPerLong - remainingBitsPerValue); + mask1 = MASKS32[remainingBitsPerValue]; + mask2 = MASKS32[remainingBitsPerLong - remainingBitsPerValue]; } tmp[tmpIdx] |= (longs[idx++] & mask1) << (remainingBitsPerLong - remainingBitsPerValue); remainingBitsPerValue = bitsPerValue - remainingBitsPerLong + remainingBitsPerValue; @@ -325,7 +326,7 @@ final class ForUtil { private static void decodeSlow(int bitsPerValue, DataInput in, long[] tmp, long[] longs) throws IOException { final int numLongs = bitsPerValue << 1; in.readLELongs(tmp, 0, numLongs); - final long mask = mask32(bitsPerValue); + final long mask = MASKS32[bitsPerValue]; int longsIdx = 0; int shift = 32 - bitsPerValue; for (; shift >= 0; shift -= bitsPerValue) { @@ -333,18 +334,18 @@ final class ForUtil { longsIdx += numLongs; } final int remainingBitsPerLong = shift + bitsPerValue; - final long mask32RemainingBitsPerLong = mask32(remainingBitsPerLong); + final long mask32RemainingBitsPerLong = MASKS32[remainingBitsPerLong]; int tmpIdx = 0; int remainingBits = remainingBitsPerLong; for (; longsIdx < BLOCK_SIZE / 2; ++longsIdx) { int b = bitsPerValue - remainingBits; - long l = (tmp[tmpIdx++] & mask32(remainingBits)) << b; + long l = (tmp[tmpIdx++] & MASKS32[remainingBits]) << b; while (b >= remainingBitsPerLong) { b -= remainingBitsPerLong; l |= (tmp[tmpIdx++] & mask32RemainingBitsPerLong) << b; } if (b > 0) { - l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & mask32(b); + l |= (tmp[tmpIdx] >>> (remainingBitsPerLong-b)) & MASKS32[b]; remainingBits = remainingBitsPerLong - b; } else { remainingBits = remainingBitsPerLong; @@ -374,6 +375,7 @@ def writeRemainderWithSIMDOptimize(bpv, next_primitive, remaining_bits_per_long, num_values /= 2 iteration *= 2 + f.write(' shiftLongs(tmp, %d, tmp, 0, 0, MASK%d_%d);\n' % (iteration * num_longs, next_primitive, remaining_bits_per_long)) f.write(' for (int iter = 0, tmpIdx = 0, longsIdx = %d; iter < %d; ++iter, tmpIdx += %d, longsIdx += %d) {\n' %(o, iteration, num_longs, num_values)) tmp_idx = 0 @@ -450,11 +452,21 @@ def writeDecode(bpv, f): if __name__ == '__main__': f = open(OUTPUT_FILE, 'w') f.write(HEADER) - for primitive_size in [8, 16, 32]: + for primitive_size in PRIMITIVE_SIZE: + f.write(' private static final long[] MASKS%d = new long[%d];\n' %(primitive_size, primitive_size)) + f.write(' static {\n') + for primitive_size in PRIMITIVE_SIZE: + f.write(' for (int i = 0; i < %d; ++i) {\n' %primitive_size) + f.write(' MASKS%d[i] = mask%d(i);\n' %(primitive_size, primitive_size)) + f.write(' }\n') + f.write(' }\n') + f.write(' //mark values in array as final longs to avoid the cost of reading array, arrays should only be used when the idx is a variable\n') + for primitive_size in PRIMITIVE_SIZE: for bpv in range(1, min(MAX_SPECIALIZED_BITS_PER_VALUE + 1, primitive_size)): if bpv * 2 != primitive_size or primitive_size == 8: - f.write(' private static final long MASK%d_%d = mask%d(%d);\n' %(primitive_size, bpv, primitive_size, bpv)) + f.write(' private static final long MASK%d_%d = MASKS%d[%d];\n' %(primitive_size, bpv, primitive_size, bpv)) f.write('\n') + f.write(""" /** * Decode 128 integers into {@code longs}.