Used reflection for generating singleton instances and addressed PR Comments

This commit is contained in:
expani 2024-10-06 17:00:17 +05:30
parent e309031c7f
commit 7dd5d80774
1 changed files with 276 additions and 219 deletions

View File

@ -16,24 +16,27 @@
*/
package org.apache.lucene.benchmark.jmh;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Scanner;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.NIOFSDirectory;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
@ -56,7 +59,7 @@ import org.openjdk.jmh.annotations.Warmup;
@Fork(value = 1)
public class DocIdEncodingBenchmark {
private static final List<int[]> DOC_ID_SEQUENCES = new ArrayList<>();
private static List<int[]> DOC_ID_SEQUENCES = new ArrayList<>();
private static int INPUT_SCALE_FACTOR;
@ -80,19 +83,19 @@ public class DocIdEncodingBenchmark {
private final int[] scratch = new int[512];
private String decoderInputFile;
@Setup(Level.Trial)
public void init() throws IOException {
tmpDir = Files.createTempDirectory("docIdJmh");
docIdEncoder = DocIdEncoder.SingletonFactory.fromName(encoderName);
// Create file once for decoders to read from in every iteration
decoderInputFile =
String.join("_", "docIdJmhData", docIdEncoder.getClass().getSimpleName(), "DecoderInput");
// Create a file for decoders ( once per trial ) to read in every JMH iteration
if (methodName.equalsIgnoreCase("decode")) {
String dataFile =
String.join("_", "docIdJmhData", docIdEncoder.getClass().getSimpleName(), "DecoderInput");
try (Directory dir = new NIOFSDirectory(tmpDir)) {
out = dir.createOutput(dataFile, IOContext.DEFAULT);
encode();
} finally {
out.close();
try (Directory dir = FSDirectory.open(tmpDir);
IndexOutput out = dir.createOutput(decoderInputFile, IOContext.DEFAULT)) {
encode(out, docIdEncoder, DOC_ID_SEQUENCES, INPUT_SCALE_FACTOR);
}
}
}
@ -100,9 +103,7 @@ public class DocIdEncodingBenchmark {
@TearDown(Level.Trial)
public void finish() throws IOException {
if (methodName.equalsIgnoreCase("decode")) {
String dataFile =
String.join("_", "docIdJmhData", docIdEncoder.getClass().getSimpleName(), "DecoderInput");
Files.delete(tmpDir.resolve(dataFile));
Files.delete(tmpDir.resolve(decoderInputFile));
}
Files.delete(tmpDir);
}
@ -110,45 +111,50 @@ public class DocIdEncodingBenchmark {
@Benchmark
public void executeEncodeOrDecode() throws IOException {
if (methodName.equalsIgnoreCase("encode")) {
String dataFile =
String outputFile =
String.join(
"_",
"docIdJmhData",
docIdEncoder.getClass().getSimpleName(),
String.valueOf(System.nanoTime()));
try (Directory dir = new NIOFSDirectory(tmpDir)) {
out = dir.createOutput(dataFile, IOContext.DEFAULT);
encode();
try (Directory dir = FSDirectory.open(tmpDir);
IndexOutput out = dir.createOutput(outputFile, IOContext.DEFAULT)) {
encode(out, docIdEncoder, DOC_ID_SEQUENCES, INPUT_SCALE_FACTOR);
} finally {
Files.delete(tmpDir.resolve(dataFile));
out.close();
Files.delete(tmpDir.resolve(outputFile));
}
} else if (methodName.equalsIgnoreCase("decode")) {
String inputFile =
String.join("_", "docIdJmhData", docIdEncoder.getClass().getSimpleName(), "DecoderInput");
try (Directory dir = new NIOFSDirectory(tmpDir)) {
in = dir.openInput(inputFile, IOContext.DEFAULT);
decode();
} finally {
in.close();
try (Directory dir = FSDirectory.open(tmpDir)) {
in = dir.openInput(decoderInputFile, IOContext.DEFAULT);
decode(in, docIdEncoder, DOC_ID_SEQUENCES, INPUT_SCALE_FACTOR, scratch);
}
} else {
throw new IllegalArgumentException("Unknown method: " + methodName);
}
}
public void encode() throws IOException {
for (int[] docIdSequence : DOC_ID_SEQUENCES) {
for (int i = 1; i <= INPUT_SCALE_FACTOR; i++) {
public void encode(
IndexOutput out, DocIdEncoder docIdEncoder, List<int[]> docIdSequences, int inputScaleFactor)
throws IOException {
for (int[] docIdSequence : docIdSequences) {
for (int i = 1; i <= inputScaleFactor; i++) {
docIdEncoder.encode(out, 0, docIdSequence.length, docIdSequence);
}
}
}
public void decode() throws IOException {
for (int[] docIdSequence : DOC_ID_SEQUENCES) {
for (int i = 1; i <= INPUT_SCALE_FACTOR; i++) {
public void decode(
IndexInput in,
DocIdEncoder docIdEncoder,
List<int[]> docIdSequences,
int inputScaleFactor,
int[] scratch)
throws IOException {
for (int[] docIdSequence : docIdSequences) {
for (int i = 1; i <= inputScaleFactor; i++) {
docIdEncoder.decode(in, 0, docIdSequence.length, scratch);
// TODO Use a unit test with a DocIdProvider that generates a few random sequences based on
// given BPV.
// Uncomment to test the output of Encoder
// if (!Arrays.equals(
// docIdSequence, Arrays.copyOfRange(scratch, 0, docIdSequence.length)))
@ -175,16 +181,27 @@ public class DocIdEncodingBenchmark {
class SingletonFactory {
static final Map<String, DocIdEncoder> ENCODER_NAME_TO_INSTANCE_MAPPING =
Map.of(
Bit24Encoder.class.getSimpleName().toLowerCase(Locale.ROOT),
new Bit24Encoder(),
Bit21With2StepsEncoder.class.getSimpleName().toLowerCase(Locale.ROOT),
new Bit21With2StepsEncoder(),
Bit21With3StepsEncoder.class.getSimpleName().toLowerCase(Locale.ROOT),
new Bit21With3StepsEncoder(),
Bit32Encoder.class.getSimpleName().toLowerCase(Locale.ROOT),
new Bit32Encoder());
static final Map<String, DocIdEncoder> ENCODER_NAME_TO_INSTANCE_MAPPING = new HashMap<>();
static {
Class<?>[] allImplementations = DocIdEncoder.class.getDeclaredClasses();
for (Class<?> clazz : allImplementations) {
boolean isADocIdEncoder =
Arrays.asList(clazz.getInterfaces()).contains(DocIdEncoder.class);
if (isADocIdEncoder) {
try {
ENCODER_NAME_TO_INSTANCE_MAPPING.put(
clazz.getSimpleName().toLowerCase(Locale.ROOT),
(DocIdEncoder) clazz.getConstructor().newInstance());
} catch (InstantiationException
| IllegalAccessException
| InvocationTargetException
| NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
}
}
public static DocIdEncoder fromName(String encoderName) {
String parsedEncoderName = encoderName.trim().toLowerCase(Locale.ROOT);
@ -195,209 +212,249 @@ public class DocIdEncodingBenchmark {
}
}
}
}
static class Bit24Encoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i;
for (i = 0; i < count - 7; i += 8) {
int doc1 = docIds[i];
int doc2 = docIds[i + 1];
int doc3 = docIds[i + 2];
int doc4 = docIds[i + 3];
int doc5 = docIds[i + 4];
int doc6 = docIds[i + 5];
int doc7 = docIds[i + 6];
int doc8 = docIds[i + 7];
long l1 = (doc1 & 0xffffffL) << 40 | (doc2 & 0xffffffL) << 16 | ((doc3 >>> 8) & 0xffffL);
long l2 =
(doc3 & 0xffL) << 56
| (doc4 & 0xffffffL) << 32
| (doc5 & 0xffffffL) << 8
| ((doc6 >> 16) & 0xffL);
long l3 = (doc6 & 0xffffL) << 48 | (doc7 & 0xffffffL) << 24 | (doc8 & 0xffffffL);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
class Bit24Encoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i;
for (i = 0; i < count - 7; i += 8) {
int doc1 = docIds[i];
int doc2 = docIds[i + 1];
int doc3 = docIds[i + 2];
int doc4 = docIds[i + 3];
int doc5 = docIds[i + 4];
int doc6 = docIds[i + 5];
int doc7 = docIds[i + 6];
int doc8 = docIds[i + 7];
long l1 = (doc1 & 0xffffffL) << 40 | (doc2 & 0xffffffL) << 16 | ((doc3 >>> 8) & 0xffffL);
long l2 =
(doc3 & 0xffL) << 56
| (doc4 & 0xffffffL) << 32
| (doc5 & 0xffffffL) << 8
| ((doc6 >> 16) & 0xffL);
long l3 = (doc6 & 0xffffL) << 48 | (doc7 & 0xffffffL) << 24 | (doc8 & 0xffffffL);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
}
for (; i < count; ++i) {
out.writeShort((short) (docIds[i] >>> 8));
out.writeByte((byte) docIds[i]);
}
}
for (; i < count; ++i) {
out.writeShort((short) (docIds[i] >>> 8));
out.writeByte((byte) docIds[i]);
@Override
public void decode(IndexInput in, int start, int count, int[] docIDs) throws IOException {
int i;
for (i = 0; i < count - 7; i += 8) {
long l1 = in.readLong();
long l2 = in.readLong();
long l3 = in.readLong();
docIDs[i] = (int) (l1 >>> 40);
docIDs[i + 1] = (int) (l1 >>> 16) & 0xffffff;
docIDs[i + 2] = (int) (((l1 & 0xffff) << 8) | (l2 >>> 56));
docIDs[i + 3] = (int) (l2 >>> 32) & 0xffffff;
docIDs[i + 4] = (int) (l2 >>> 8) & 0xffffff;
docIDs[i + 5] = (int) (((l2 & 0xff) << 16) | (l3 >>> 48));
docIDs[i + 6] = (int) (l3 >>> 24) & 0xffffff;
docIDs[i + 7] = (int) l3 & 0xffffff;
}
for (; i < count; ++i) {
docIDs[i] =
(Short.toUnsignedInt(in.readShort()) << 8) | Byte.toUnsignedInt(in.readByte());
}
}
}
@Override
public void decode(IndexInput in, int start, int count, int[] docIDs) throws IOException {
int i;
for (i = 0; i < count - 7; i += 8) {
long l1 = in.readLong();
long l2 = in.readLong();
long l3 = in.readLong();
docIDs[i] = (int) (l1 >>> 40);
docIDs[i + 1] = (int) (l1 >>> 16) & 0xffffff;
docIDs[i + 2] = (int) (((l1 & 0xffff) << 8) | (l2 >>> 56));
docIDs[i + 3] = (int) (l2 >>> 32) & 0xffffff;
docIDs[i + 4] = (int) (l2 >>> 8) & 0xffffff;
docIDs[i + 5] = (int) (((l2 & 0xff) << 16) | (l3 >>> 48));
docIDs[i + 6] = (int) (l3 >>> 24) & 0xffffff;
docIDs[i + 7] = (int) l3 & 0xffffff;
class Bit21With2StepsEncoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i = 0;
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
out.writeLong(packedLong);
}
for (; i < count; i++) {
out.writeInt(docIds[i]);
}
}
for (; i < count; ++i) {
docIDs[i] = (Short.toUnsignedInt(in.readShort()) << 8) | Byte.toUnsignedInt(in.readByte());
@Override
public void decode(IndexInput in, int start, int count, int[] docIDs) throws IOException {
int i = 0;
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();
}
}
}
/**
* Variation of @{@link Bit21With2StepsEncoder} but uses 3 loops to decode the array of DocIds.
* Comparatively better than @{@link Bit21With2StepsEncoder} on aarch64 with JDK 22
*/
class Bit21With3StepsEncoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i = 0;
for (; i < count - 8; i += 9) {
long l1 =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
long l2 =
((docIds[i + 3] & 0x001FFFFFL) << 42)
| ((docIds[i + 4] & 0x001FFFFFL) << 21)
| (docIds[i + 5] & 0x001FFFFFL);
long l3 =
((docIds[i + 6] & 0x001FFFFFL) << 42)
| ((docIds[i + 7] & 0x001FFFFFL) << 21)
| (docIds[i + 8] & 0x001FFFFFL);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
}
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
out.writeLong(packedLong);
}
for (; i < count; i++) {
out.writeInt(docIds[i]);
}
}
@Override
public void decode(IndexInput in, int start, int count, int[] docIDs) throws IOException {
int i = 0;
for (; i < count - 8; i += 9) {
long l1 = in.readLong();
long l2 = in.readLong();
long l3 = in.readLong();
docIDs[i] = (int) (l1 >>> 42);
docIDs[i + 1] = (int) ((l1 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (l1 & 0x001FFFFFL);
docIDs[i + 3] = (int) (l2 >>> 42);
docIDs[i + 4] = (int) ((l2 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 5] = (int) (l2 & 0x001FFFFFL);
docIDs[i + 6] = (int) (l3 >>> 42);
docIDs[i + 7] = (int) ((l3 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 8] = (int) (l3 & 0x001FFFFFL);
}
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();
}
}
}
class Bit32Encoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
for (int i = 0; i < count; i++) {
out.writeInt(docIds[i]);
}
}
@Override
public void decode(IndexInput in, int start, int count, int[] docIds) throws IOException {
for (int i = 0; i < count; i++) {
docIds[i] = in.readInt();
}
}
}
}
static class Bit21With2StepsEncoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i = 0;
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
out.writeLong(packedLong);
}
for (; i < count; i++) {
out.writeInt(docIds[i]);
}
}
@Override
public void decode(IndexInput in, int start, int count, int[] docIDs) throws IOException {
int i = 0;
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();
}
}
interface DocIdProvider {
/**
* We want to load all the docId sequences completely in memory to avoid including the time
* spent in fetching from disk. <br>
*
* @return: All the docId sequences or empty list.
*/
List<int[]> getDocIds(Object... args);
}
/**
* Variation of @{@link Bit21With2StepsEncoder} but uses 3 loops to decode the array of DocIds.
* Comparatively better than @{@link Bit21With2StepsEncoder} on aarch64 with JDK 22
*/
static class Bit21With3StepsEncoder implements DocIdEncoder {
static class DocIdsFromLocalFS implements DocIdProvider {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
int i = 0;
for (; i < count - 8; i += 9) {
long l1 =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
long l2 =
((docIds[i + 3] & 0x001FFFFFL) << 42)
| ((docIds[i + 4] & 0x001FFFFFL) << 21)
| (docIds[i + 5] & 0x001FFFFFL);
long l3 =
((docIds[i + 6] & 0x001FFFFFL) << 42)
| ((docIds[i + 7] & 0x001FFFFFL) << 21)
| (docIds[i + 8] & 0x001FFFFFL);
out.writeLong(l1);
out.writeLong(l2);
out.writeLong(l3);
}
for (; i < count - 2; i += 3) {
long packedLong =
((docIds[i] & 0x001FFFFFL) << 42)
| ((docIds[i + 1] & 0x001FFFFFL) << 21)
| (docIds[i + 2] & 0x001FFFFFL);
out.writeLong(packedLong);
}
for (; i < count; i++) {
out.writeInt(docIds[i]);
}
}
@Override
public void decode(IndexInput in, int start, int count, int[] docIDs) throws IOException {
int i = 0;
for (; i < count - 8; i += 9) {
long l1 = in.readLong();
long l2 = in.readLong();
long l3 = in.readLong();
docIDs[i] = (int) (l1 >>> 42);
docIDs[i + 1] = (int) ((l1 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (l1 & 0x001FFFFFL);
docIDs[i + 3] = (int) (l2 >>> 42);
docIDs[i + 4] = (int) ((l2 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 5] = (int) (l2 & 0x001FFFFFL);
docIDs[i + 6] = (int) (l3 >>> 42);
docIDs[i + 7] = (int) ((l3 & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 8] = (int) (l3 & 0x001FFFFFL);
}
for (; i < count - 2; i += 3) {
long packedLong = in.readLong();
docIDs[i] = (int) (packedLong >>> 42);
docIDs[i + 1] = (int) ((packedLong & 0x000003FFFFE00000L) >>> 21);
docIDs[i + 2] = (int) (packedLong & 0x001FFFFFL);
}
for (; i < count; i++) {
docIDs[i] = in.readInt();
}
}
}
static class Bit32Encoder implements DocIdEncoder {
@Override
public void encode(IndexOutput out, int start, int count, int[] docIds) throws IOException {
for (int i = 0; i < count; i++) {
out.writeInt(docIds[i]);
}
}
@Override
public void decode(IndexInput in, int start, int count, int[] docIds) throws IOException {
for (int i = 0; i < count; i++) {
docIds[i] = in.readInt();
public List<int[]> getDocIds(Object... args) {
List<int[]> docIds = new ArrayList<>();
InputStream fileContents = (InputStream) args[0];
try (Scanner fileReader = new Scanner(fileContents, Charset.defaultCharset())) {
while (fileReader.hasNextLine()) {
String sequence = fileReader.nextLine().trim();
if (!sequence.startsWith("#") && !sequence.isEmpty()) {
docIds.add(
Arrays.stream(sequence.split(","))
.map(String::trim)
.mapToInt(Integer::parseInt)
.toArray());
}
}
}
return docIds;
}
}
private static void parseInput() {
String inputScaleFactor = System.getProperty("docIdEncoding.inputScaleFactor");
if (inputScaleFactor != null) {
if (inputScaleFactor != null && !inputScaleFactor.isEmpty()) {
INPUT_SCALE_FACTOR = Integer.parseInt(inputScaleFactor);
} else {
INPUT_SCALE_FACTOR = 2_00_000;
}
String inputFilePath = System.getProperty("docIdEncoding.inputFile");
Scanner fileReader = null;
try {
if (inputFilePath != null) {
fileReader = new Scanner(Paths.get(inputFilePath), Charset.defaultCharset());
} else {
fileReader =
new Scanner(
Objects.requireNonNull(
DocIdEncodingBenchmark.class.getResourceAsStream(
"/org.apache.lucene.benchmark.jmh/docIds_bpv21.txt")),
Charset.defaultCharset());
String docProviderFQDN = System.getProperty("docIdEncoding.docIdProviderFQDN");
DocIdProvider docIdProvider = new DocIdsFromLocalFS();
if (docProviderFQDN != null && !docProviderFQDN.isEmpty()) {
try {
docIdProvider =
(DocIdProvider) Class.forName(docProviderFQDN).getConstructor().newInstance();
} catch (InstantiationException
| IllegalAccessException
| InvocationTargetException
| NoSuchMethodException
| ClassNotFoundException e) {
throw new RuntimeException(e);
}
while (fileReader.hasNextLine()) {
String sequence = fileReader.nextLine().trim();
if (!sequence.startsWith("#") && !sequence.isEmpty()) {
DOC_ID_SEQUENCES.add(
Arrays.stream(sequence.split(",")).map(String::trim).mapToInt(Integer::parseInt).toArray());
}
if (docIdProvider instanceof DocIdsFromLocalFS) {
String inputFilePath = System.getProperty("docIdEncoding.inputFile");
try {
if (inputFilePath != null && !inputFilePath.isEmpty()) {
DOC_ID_SEQUENCES = docIdProvider.getDocIds(new FileInputStream(inputFilePath));
} else {
DOC_ID_SEQUENCES =
docIdProvider.getDocIds(
DocIdEncodingBenchmark.class.getResourceAsStream(
"/org.apache.lucene.benchmark.jmh/docIds_bpv21.txt"));
}
}
} catch (IOException e) {
throw new RuntimeException(e);
} finally {
if (fileReader != null) {
fileReader.close();
} catch (FileNotFoundException e) {
throw new RuntimeException(e);
}
}
}