Simplify PForUtil construction and clean up its code gen a little (#13932)

Generate cleaner code for ForUtil and ForDeltaUtil: generated decode methods
that never use `tmp` no longer take it as a (dead) parameter.

Also:
PForUtil instances always use their own `ForUtil`, so we can inline
its creation into the field declaration. Also, we can save cycles by
reading the input from PostingDecodingUtil once into a local variable.

Surprisingly, the combination of these cleanups yields a small but
statistically clearly visible speedup that the compiler apparently
isn't able to achieve on its own.
This commit is contained in:
Armin Braun 2024-10-21 15:10:08 +02:00 committed by GitHub
parent 05e06e51ec
commit 66f22fa0fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 54 additions and 44 deletions

View File

@ -1,4 +1,4 @@
{
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForDeltaUtil.java": "f561578ccb6a95364bb62c5ed86b38ff0b4a009d",
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForDeltaUtil.py": "eea1a71be9da8a13fdd979354dc4a8c6edf21be1"
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForDeltaUtil.java": "b662da5848b0decc8bceb4225f433875ae9e3c11",
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForDeltaUtil.py": "01787b97bbe79edb7703498cef8ddb85901a6b1e"
}

View File

@ -1,4 +1,4 @@
{
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java": "159e82388346fde147924d5e15ca65df4dd63b9a",
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py": "66dc8813160feae2a37d8b50474f5f9830b6cb22"
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/ForUtil.java": "02e0c8c290e65d0314664fde24c9331bdec44925",
"lucene/core/src/java/org/apache/lucene/codecs/lucene912/gen_ForUtil.py": "d7850f37e52a16c6592322950d0f6219cad23a33"
}

View File

@ -286,11 +286,11 @@ public final class ForDeltaUtil {
throws IOException {
switch (bitsPerValue) {
case 1:
decode1(pdu, tmp, longs);
decode1(pdu, longs);
prefixSum8(longs, base);
break;
case 2:
decode2(pdu, tmp, longs);
decode2(pdu, longs);
prefixSum8(longs, base);
break;
case 3:
@ -298,7 +298,7 @@ public final class ForDeltaUtil {
prefixSum8(longs, base);
break;
case 4:
decode4(pdu, tmp, longs);
decode4(pdu, longs);
prefixSum8(longs, base);
break;
case 5:
@ -314,7 +314,7 @@ public final class ForDeltaUtil {
prefixSum16(longs, base);
break;
case 8:
decode8To16(pdu, tmp, longs);
decode8To16(pdu, longs);
prefixSum16(longs, base);
break;
case 9:
@ -346,7 +346,7 @@ public final class ForDeltaUtil {
prefixSum32(longs, base);
break;
case 16:
decode16To32(pdu, tmp, longs);
decode16To32(pdu, longs);
prefixSum32(longs, base);
break;
case 17:
@ -431,8 +431,7 @@ public final class ForDeltaUtil {
}
}
private static void decode8To16(PostingDecodingUtil pdu, long[] tmp, long[] longs)
throws IOException {
private static void decode8To16(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.splitLongs(16, longs, 8, 8, MASK16_8, longs, 16, MASK16_8);
}
@ -522,8 +521,7 @@ public final class ForDeltaUtil {
}
}
private static void decode16To32(PostingDecodingUtil pdu, long[] tmp, long[] longs)
throws IOException {
private static void decode16To32(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.splitLongs(32, longs, 16, 16, MASK32_16, longs, 32, MASK32_16);
}
}

View File

@ -291,11 +291,11 @@ public final class ForUtil {
void decode(int bitsPerValue, PostingDecodingUtil pdu, long[] longs) throws IOException {
switch (bitsPerValue) {
case 1:
decode1(pdu, tmp, longs);
decode1(pdu, longs);
expand8(longs);
break;
case 2:
decode2(pdu, tmp, longs);
decode2(pdu, longs);
expand8(longs);
break;
case 3:
@ -303,7 +303,7 @@ public final class ForUtil {
expand8(longs);
break;
case 4:
decode4(pdu, tmp, longs);
decode4(pdu, longs);
expand8(longs);
break;
case 5:
@ -319,7 +319,7 @@ public final class ForUtil {
expand8(longs);
break;
case 8:
decode8(pdu, tmp, longs);
decode8(pdu, longs);
expand8(longs);
break;
case 9:
@ -351,7 +351,7 @@ public final class ForUtil {
expand16(longs);
break;
case 16:
decode16(pdu, tmp, longs);
decode16(pdu, longs);
expand16(longs);
break;
case 17:
@ -393,11 +393,11 @@ public final class ForUtil {
}
}
static void decode1(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {
static void decode1(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.splitLongs(2, longs, 7, 1, MASK8_1, longs, 14, MASK8_1);
}
static void decode2(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {
static void decode2(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.splitLongs(4, longs, 6, 2, MASK8_2, longs, 12, MASK8_2);
}
@ -413,7 +413,7 @@ public final class ForUtil {
}
}
static void decode4(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {
static void decode4(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.splitLongs(8, longs, 4, 4, MASK8_4, longs, 8, MASK8_4);
}
@ -457,7 +457,7 @@ public final class ForUtil {
}
}
static void decode8(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {
static void decode8(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.in.readLongs(longs, 0, 16);
}
@ -601,7 +601,7 @@ public final class ForUtil {
}
}
static void decode16(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {
static void decode16(PostingDecodingUtil pdu, long[] longs) throws IOException {
pdu.in.readLongs(longs, 0, 32);
}

View File

@ -427,7 +427,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
public PostingsEnum reset(IntBlockTermState termState, int flags) throws IOException {
resetIndexInput(termState);
if (pforUtil == null && docFreq >= BLOCK_SIZE) {
pforUtil = new PForUtil(new ForUtil());
pforUtil = new PForUtil();
forDeltaUtil = new ForDeltaUtil();
}
totalTermFreq = indexHasFreq ? termState.totalTermFreq : docFreq;
@ -727,7 +727,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
}
totalTermFreq = termState.totalTermFreq;
if (pforUtil == null && totalTermFreq >= BLOCK_SIZE) {
pforUtil = new PForUtil(new ForUtil());
pforUtil = new PForUtil();
}
// Where this term's postings start in the .pos file:
final long posTermStartFP = termState.posStartFP;
@ -1142,7 +1142,7 @@ public final class Lucene912PostingsReader extends PostingsReaderBase {
private abstract class BlockImpactsEnum extends ImpactsEnum {
protected final ForDeltaUtil forDeltaUtil = new ForDeltaUtil();
protected final PForUtil pforUtil = new PForUtil(new ForUtil());
protected final PForUtil pforUtil = new PForUtil();
protected final long[] docBuffer = new long[BLOCK_SIZE + 1];
protected final long[] freqBuffer = new long[BLOCK_SIZE];

View File

@ -142,9 +142,8 @@ public class Lucene912PostingsWriter extends PushPostingsWriterBase {
metaOut, META_CODEC, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix);
CodecUtil.writeIndexHeader(
docOut, DOC_CODEC, VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix);
final ForUtil forUtil = new ForUtil();
forDeltaUtil = new ForDeltaUtil();
pforUtil = new PForUtil(forUtil);
pforUtil = new PForUtil();
if (state.fieldInfos.hasProx()) {
posDeltaBuffer = new long[BLOCK_SIZE];
String posFileName =

View File

@ -38,11 +38,10 @@ final class PForUtil {
return true;
}
private final ForUtil forUtil;
private final ForUtil forUtil = new ForUtil();
PForUtil(ForUtil forUtil) {
static {
assert ForUtil.BLOCK_SIZE <= 256 : "blocksize must fit in one byte. got " + ForUtil.BLOCK_SIZE;
this.forUtil = forUtil;
}
/** Encode 128 integers from {@code longs} into {@code out}. */
@ -106,17 +105,18 @@ final class PForUtil {
/** Decode 128 integers into {@code longs}. */
void decode(PostingDecodingUtil pdu, long[] longs) throws IOException {
final int token = Byte.toUnsignedInt(pdu.in.readByte());
var in = pdu.in;
final int token = Byte.toUnsignedInt(in.readByte());
final int bitsPerValue = token & 0x1f;
final int numExceptions = token >>> 5;
if (bitsPerValue == 0) {
Arrays.fill(longs, 0, ForUtil.BLOCK_SIZE, pdu.in.readVLong());
Arrays.fill(longs, 0, ForUtil.BLOCK_SIZE, in.readVLong());
} else {
forUtil.decode(bitsPerValue, pdu, longs);
}
final int numExceptions = token >>> 5;
for (int i = 0; i < numExceptions; ++i) {
longs[Byte.toUnsignedInt(pdu.in.readByte())] |=
Byte.toUnsignedLong(pdu.in.readByte()) << bitsPerValue;
longs[Byte.toUnsignedInt(in.readByte())] |=
Byte.toUnsignedLong(in.readByte()) << bitsPerValue;
}
}

View File

@ -361,7 +361,10 @@ def writeRemainder(bpv, next_primitive, remaining_bits_per_long, o, num_values,
def writeDecode(bpv, f):
next_primitive = primitive_size_for_bpv(bpv)
f.write(' private static void decode%dTo%d(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {\n' %(bpv, next_primitive))
if next_primitive % bpv == 0:
f.write(' private static void decode%dTo%d(PostingDecodingUtil pdu, long[] longs) throws IOException {\n' %(bpv, next_primitive))
else:
f.write(' private static void decode%dTo%d(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {\n' %(bpv, next_primitive))
if bpv == next_primitive:
f.write(' pdu.in.readLongs(longs, 0, %d);\n' %(bpv*2))
else:
@ -390,9 +393,15 @@ if __name__ == '__main__':
primitive_size = primitive_size_for_bpv(bpv)
f.write(' case %d:\n' %bpv)
if next_primitive(bpv) == primitive_size:
f.write(' decode%d(pdu, tmp, longs);\n' %bpv)
if primitive_size % bpv == 0:
f.write(' decode%d(pdu, longs);\n' %bpv)
else:
f.write(' decode%d(pdu, tmp, longs);\n' %bpv)
else:
f.write(' decode%dTo%d(pdu, tmp, longs);\n' %(bpv, primitive_size))
if primitive_size % bpv == 0:
f.write(' decode%dTo%d(pdu, longs);\n' %(bpv, primitive_size))
else:
f.write(' decode%dTo%d(pdu, tmp, longs);\n' %(bpv, primitive_size))
f.write(' prefixSum%d(longs, base);\n' %primitive_size)
f.write(' break;\n')
f.write(' default:\n')

View File

@ -287,8 +287,8 @@ def writeDecode(bpv, f):
next_primitive = 8
elif bpv <= 16:
next_primitive = 16
f.write(' static void decode%d(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {\n' %bpv)
if bpv == next_primitive:
f.write(' static void decode%d(PostingDecodingUtil pdu, long[] longs) throws IOException {\n' %bpv)
f.write(' pdu.in.readLongs(longs, 0, %d);\n' %(bpv*2))
else:
num_values_per_long = 64 / next_primitive
@ -296,8 +296,10 @@ def writeDecode(bpv, f):
num_iters = (next_primitive - 1) // bpv
o = 2 * bpv * num_iters
if remaining_bits == 0:
f.write(' static void decode%d(PostingDecodingUtil pdu, long[] longs) throws IOException {\n' %bpv)
f.write(' pdu.splitLongs(%d, longs, %d, %d, MASK%d_%d, longs, %d, MASK%d_%d);\n' %(bpv*2, next_primitive - bpv, bpv, next_primitive, bpv, o, next_primitive, next_primitive - num_iters * bpv))
else:
f.write(' static void decode%d(PostingDecodingUtil pdu, long[] tmp, long[] longs) throws IOException {\n' %bpv)
f.write(' pdu.splitLongs(%d, longs, %d, %d, MASK%d_%d, tmp, 0, MASK%d_%d);\n' %(bpv*2, next_primitive - bpv, bpv, next_primitive, bpv, next_primitive, next_primitive - num_iters * bpv))
writeRemainder(bpv, next_primitive, remaining_bits, o, 128/num_values_per_long - o, f)
f.write(' }\n')
@ -334,7 +336,10 @@ if __name__ == '__main__':
elif bpv <= 16:
next_primitive = 16
f.write(' case %d:\n' %bpv)
f.write(' decode%d(pdu, tmp, longs);\n' %bpv)
if next_primitive % bpv == 0:
f.write(' decode%d(pdu, longs);\n' %bpv)
else:
f.write(' decode%d(pdu, tmp, longs);\n' %bpv)
f.write(' expand%d(longs);\n' %next_primitive)
f.write(' break;\n')
f.write(' default:\n')

View File

@ -39,11 +39,10 @@ public class TestPForUtil extends LuceneTestCase {
final Directory d = new ByteBuffersDirectory();
final long endPointer = encodeTestData(iterations, values, d);
ForUtil forUtil = new ForUtil();
IndexInput in = d.openInput("test.bin", IOContext.READONCE);
PostingDecodingUtil pdu =
Lucene912PostingsReader.VECTORIZATION_PROVIDER.newPostingDecodingUtil(in);
final PForUtil pforUtil = new PForUtil(forUtil);
final PForUtil pforUtil = new PForUtil();
for (int i = 0; i < iterations; ++i) {
if (random().nextInt(5) == 0) {
PForUtil.skip(in);
@ -91,7 +90,7 @@ public class TestPForUtil extends LuceneTestCase {
private long encodeTestData(int iterations, int[] values, Directory d) throws IOException {
IndexOutput out = d.createOutput("test.bin", IOContext.DEFAULT);
final PForUtil pforUtil = new PForUtil(new ForUtil());
final PForUtil pforUtil = new PForUtil();
for (int i = 0; i < iterations; ++i) {
long[] source = new long[ForUtil.BLOCK_SIZE];