Made encoding conditional based on architecture

This commit is contained in:
expani 2024-10-16 14:33:03 +05:30
parent 33fd619dd0
commit 295531c9a3
2 changed files with 77 additions and 55 deletions

View File

@ -61,6 +61,8 @@ import org.openjdk.jmh.annotations.Warmup;
@Fork(value = 1)
public class DocIdEncodingBenchmark {
private static final long BPV_21_MASK = 0x1FFFFFL;
private static List<int[]> DOC_ID_SEQUENCES = new ArrayList<>();
private static int INPUT_SCALE_FACTOR;
@ -291,15 +293,19 @@ public class DocIdEncodingBenchmark {
}
}
/**
* Uses 21 bits to represent an integer and can store 3 docIds within a long. This is the
* simplified version, which is faster at encoding on aarch64.
*/
class Bit21With2StepsEncoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i = 0;
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
((docIds[i] & BPV_21_MASK) << 42)
| ((docIds[i + 1] & BPV_21_MASK) << 21)
| (docIds[i + 2] & BPV_21_MASK);
out.writeLong(packedLong);
}
for (; i < count; i++) {
@ -313,8 +319,8 @@ public class DocIdEncodingBenchmark {
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
docIDs[i + 1] = (int) ((packedLong >>> 21) & BPV_21_MASK);
docIDs[i + 2] = (int) (packedLong & BPV_21_MASK);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();
@ -324,7 +330,8 @@ public class DocIdEncodingBenchmark {
/**
* Variation of {@link Bit21With2StepsEncoder} but uses 3 loops to decode the array of DocIds.
* Comparatively better than {@link Bit21With2StepsEncoder} on aarch64 with JDK 22
* Comparatively better at decoding than {@link Bit21With2StepsEncoder} on aarch64 with JDK 22,
* but poorer at encoding.
*/
class Bit21With3StepsEncoder implements DocIdEncoder {
@ -333,26 +340,26 @@ public class DocIdEncodingBenchmark {
int i = 0;
for (; i < count - 8; i += 9) {
long l1 =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
((docIds[i] & BPV_21_MASK) << 42)
| ((docIds[i + 1] & BPV_21_MASK) << 21)
| (docIds[i + 2] & BPV_21_MASK);
long l2 =
((docIds[i + 3] & 0x001FFFFFL) << 42)
| ((docIds[i + 4] & 0x001FFFFFL) << 21)
| (docIds[i + 5] & 0x001FFFFFL);
((docIds[i + 3] & BPV_21_MASK) << 42)
| ((docIds[i + 4] & BPV_21_MASK) << 21)
| (docIds[i + 5] & BPV_21_MASK);
long l3 =
((docIds[i + 6] & 0x001FFFFFL) << 42)
| ((docIds[i + 7] & 0x001FFFFFL) << 21)
| (docIds[i + 8] & 0x001FFFFFL);
((docIds[i + 6] & BPV_21_MASK) << 42)
| ((docIds[i + 7] & BPV_21_MASK) << 21)
| (docIds[i + 8] & BPV_21_MASK);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
}
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
((docIds[i] & BPV_21_MASK) << 42)
| ((docIds[i + 1] & BPV_21_MASK) << 21)
| (docIds[i + 2] & BPV_21_MASK);
out.writeLong(packedLong);
}
for (; i < count; i++) {
@ -368,20 +375,20 @@ public class DocIdEncodingBenchmark {
long l2 = in.readLong();
long l3 = in.readLong();
docIDs[i] = (int) (l1 >>> 42);
docIDs[i + 1] = (int) ((l1 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (l1 & 0x001FFFFFL);
docIDs[i + 1] = (int) ((l1 >>> 21) & BPV_21_MASK);
docIDs[i + 2] = (int) (l1 & BPV_21_MASK);
docIDs[i + 3] = (int) (l2 >>> 42);
docIDs[i + 4] = (int) ((l2 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 5] = (int) (l2 & 0x001FFFFFL);
docIDs[i + 4] = (int) ((l2 >>> 21) & BPV_21_MASK);
docIDs[i + 5] = (int) (l2 & BPV_21_MASK);
docIDs[i + 6] = (int) (l3 >>> 42);
docIDs[i + 7] = (int) ((l3 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 8] = (int) (l3 & 0x001FFFFFL);
docIDs[i + 7] = (int) ((l3 >>> 21) & BPV_21_MASK);
docIDs[i + 8] = (int) (l3 & BPV_21_MASK);
}
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
docIDs[i + 1] = (int) ((packedLong >>> 21) & BPV_21_MASK);
docIDs[i + 2] = (int) (packedLong & BPV_21_MASK);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();
@ -518,9 +525,8 @@ public class DocIdEncodingBenchmark {
INPUT_SCALE_FACTOR = 10;
}
String inputFilePath = System.getProperty("docIdEncoding.inputFile");
try {
String inputFilePath = System.getProperty("docIdEncoding.inputFile");
if (inputFilePath != null && !inputFilePath.isEmpty()) {
DOC_ID_SEQUENCES =
new DocIdsFromLocalFS()

View File

@ -23,6 +23,7 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.DocBaseBitSetIterator;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.IntsRef;
@ -39,6 +40,10 @@ final class DocIdsWriter {
// These signs are legacy, should no longer be used in the writing side.
private static final byte LEGACY_DELTA_VINT = (byte) 0;
// Mask selecting the low 21 bits of a value; three such 21-bit values are packed into one long.
private static final long BPV_21_MASK = 0x1FFFFFL;
// True only when Constants.OS_ARCH is exactly "aarch64". Despite the name, this does NOT mean
// "any 64-bit architecture": it is false on every other architecture string.
private static final boolean IS_ARCH_64 = Constants.OS_ARCH.equals("aarch64");
private final int[] scratch;
private final LongsRef scratchLongs = new LongsRef();
@ -116,28 +121,32 @@ final class DocIdsWriter {
if (max <= 0x001FFFFF) {
out.writeByte(BPV_21);
int i = 0;
for (; i < count - 8; i += 9) {
long l1 =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
long l2 =
((docIds[i + 3] & 0x001FFFFFL) << 42)
| ((docIds[i + 4] & 0x001FFFFFL) << 21)
| (docIds[i + 5] & 0x001FFFFFL);
long l3 =
((docIds[i + 6] & 0x001FFFFFL) << 42)
| ((docIds[i + 7] & 0x001FFFFFL) << 21)
| (docIds[i + 8] & 0x001FFFFFL);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
// See
// org.apache.lucene.benchmark.jmh.DocIdEncodingBenchmark.DocIdEncoder.Bit21With3StepsEncoder:
// the 9-docIds-per-iteration loop is skipped on aarch64, where it was found to encode slower.
if (!IS_ARCH_64) {
for (; i < count - 8; i += 9) {
long l1 =
((docIds[i] & BPV_21_MASK) << 42)
| ((docIds[i + 1] & BPV_21_MASK) << 21)
| (docIds[i + 2] & BPV_21_MASK);
long l2 =
((docIds[i + 3] & BPV_21_MASK) << 42)
| ((docIds[i + 4] & BPV_21_MASK) << 21)
| (docIds[i + 5] & BPV_21_MASK);
long l3 =
((docIds[i + 6] & BPV_21_MASK) << 42)
| ((docIds[i + 7] & BPV_21_MASK) << 21)
| (docIds[i + 8] & BPV_21_MASK);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
}
}
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
((docIds[i] & BPV_21_MASK) << 42)
| ((docIds[i + 1] & BPV_21_MASK) << 21)
| (docIds[i + 2] & BPV_21_MASK);
out.writeLong(packedLong);
}
for (; i < count; i++) {
@ -298,25 +307,32 @@ final class DocIdsWriter {
private void readInts21(IndexInput in, int count, int[] docIDs) throws IOException {
int i = 0;
// We are always using
// org.apache.lucene.benchmark.jmh.DocIdEncodingBenchmark.DocIdEncoder.Bit21With3StepsEncoder
// over
// org.apache.lucene.benchmark.jmh.DocIdEncodingBenchmark.DocIdEncoder.Bit21With2StepsEncoder
// for decoding irrespective of architecture
// due to its better performance in benchmarks over nyc taxis, big5, http_logs and other
// popular search workloads.
for (; i < count - 8; i += 9) {
long l1 = in.readLong();
long l2 = in.readLong();
long l3 = in.readLong();
docIDs[i] = (int) (l1 >>> 42);
docIDs[i + 1] = (int) ((l1 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (l1 & 0x001FFFFFL);
docIDs[i + 1] = (int) ((l1 >>> 21) & BPV_21_MASK);
docIDs[i + 2] = (int) (l1 & BPV_21_MASK);
docIDs[i + 3] = (int) (l2 >>> 42);
docIDs[i + 4] = (int) ((l2 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 5] = (int) (l2 & 0x001FFFFFL);
docIDs[i + 4] = (int) ((l2 >>> 21) & BPV_21_MASK);
docIDs[i + 5] = (int) (l2 & BPV_21_MASK);
docIDs[i + 6] = (int) (l3 >>> 42);
docIDs[i + 7] = (int) ((l3 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 8] = (int) (l3 & 0x001FFFFFL);
docIDs[i + 7] = (int) ((l3 >>> 21) & BPV_21_MASK);
docIDs[i + 8] = (int) (l3 & BPV_21_MASK);
}
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
docIDs[i + 1] = (int) ((packedLong >>> 21) & BPV_21_MASK);
docIDs[i + 2] = (int) (packedLong & BPV_21_MASK);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();