Speed up writeGroupVInts (#13203)

* Speed up writeGroupVInts
This commit is contained in:
Zhang Chao 2024-03-26 22:48:30 +08:00 committed by GitHub
parent 26f5065e15
commit 1f909baca5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 59 additions and 29 deletions

View File

@ -235,10 +235,13 @@ Optimizations
* GITHUB#13121: Speedup multi-segment HNSW graph search for diversifying child kNN queries. Builds on GITHUB#12962.
(Ben Trent)
* GITHUB#13184: Make the HitQueue size more appropriate for KNN exact search (Pan Guixin)
* GITHUB#13199: Speed up dynamic pruning by breaking point estimation when threshold get exceeded. (Guo Feng)
* GITHUB#13203: Speed up writeGroupVInts (Zhang Chao)
Bug Fixes
---------------------

View File

@ -86,6 +86,7 @@ public class GroupVIntBenchmark {
};
final int maxSize = 256;
final long[] docs = new long[maxSize];
final long[] values = new long[maxSize];
IndexInput byteBufferGVIntIn;
@ -96,6 +97,9 @@ public class GroupVIntBenchmark {
ByteArrayDataInput byteArrayVIntIn;
ByteArrayDataInput byteArrayGVIntIn;
// benchmark for write
ByteBuffersDataOutput byteBuffersGVIntOut = new ByteBuffersDataOutput();
@Param({"64"})
public int size;
@ -153,7 +157,6 @@ public class GroupVIntBenchmark {
@Setup(Level.Trial)
public void init() throws Exception {
long[] docs = new long[maxSize];
Random r = new Random(0);
for (int i = 0; i < maxSize; ++i) {
float randomFloat = r.nextFloat();
@ -237,4 +240,10 @@ public class GroupVIntBenchmark {
this.readGroupVIntsBaseline(byteBuffersGVIntIn, values, size);
bh.consume(values);
}
@Benchmark
public void bench_writeGroupVInt(Blackhole bh) throws IOException {
byteBuffersGVIntOut.reset();
byteBuffersGVIntOut.writeGroupVInts(docs, size);
}
}

View File

@ -21,7 +21,7 @@ import java.util.Map;
import java.util.Set;
import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.apache.lucene.util.GroupVIntUtil;
/**
* Abstract base class for performing write operations of Lucene's low-level data types.
@ -30,7 +30,7 @@ import org.apache.lucene.util.BytesRefBuilder;
* internal state like file position).
*/
public abstract class DataOutput {
private final BytesRefBuilder groupVIntBytes = new BytesRefBuilder();
private byte[] groupVIntBytes;
/**
* Writes a single byte.
@ -335,32 +335,9 @@ public abstract class DataOutput {
* @lucene.experimental
*/
public void writeGroupVInts(long[] values, int limit) throws IOException {
int off = 0;
// encode each group
while ((limit - off) >= 4) {
byte flag = 0;
groupVIntBytes.setLength(1);
flag |= (encodeGroupValue(Math.toIntExact(values[off++])) - 1) << 6;
flag |= (encodeGroupValue(Math.toIntExact(values[off++])) - 1) << 4;
flag |= (encodeGroupValue(Math.toIntExact(values[off++])) - 1) << 2;
flag |= (encodeGroupValue(Math.toIntExact(values[off++])) - 1);
groupVIntBytes.setByteAt(0, flag);
writeBytes(groupVIntBytes.bytes(), groupVIntBytes.length());
if (groupVIntBytes == null) {
groupVIntBytes = new byte[GroupVIntUtil.MAX_LENGTH_PER_GROUP];
}
// tail vints
for (; off < limit; off++) {
writeVInt(Math.toIntExact(values[off]));
}
}
private int encodeGroupValue(int v) {
int lastOff = groupVIntBytes.length();
do {
groupVIntBytes.append((byte) (v & 0xFF));
v >>>= 8;
} while (v != 0);
return groupVIntBytes.length() - lastOff;
GroupVIntUtil.writeGroupVInts(this, groupVIntBytes, values, limit);
}
}

View File

@ -18,6 +18,7 @@ package org.apache.lucene.util;
import java.io.IOException;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput;
/**
* This class contains utility methods and constants for group varint
@ -111,4 +112,44 @@ public final class GroupVIntUtil {
pos += 1 + n4Minus1;
return (int) (pos - posStart);
}
private static int numBytes(int v) {
// | 1 to return 1 when v = 0
return Integer.BYTES - (Integer.numberOfLeadingZeros(v | 1) >> 3);
}
/**
* The implementation for group-varint encoding, It uses a maximum of {@link
* #MAX_LENGTH_PER_GROUP} bytes scratch buffer.
*/
public static void writeGroupVInts(DataOutput out, byte[] scratch, long[] values, int limit)
throws IOException {
int readPos = 0;
// encode each group
while ((limit - readPos) >= 4) {
int writePos = 0;
final int n1Minus1 = numBytes(Math.toIntExact(values[readPos])) - 1;
final int n2Minus1 = numBytes(Math.toIntExact(values[readPos + 1])) - 1;
final int n3Minus1 = numBytes(Math.toIntExact(values[readPos + 2])) - 1;
final int n4Minus1 = numBytes(Math.toIntExact(values[readPos + 3])) - 1;
int flag = (n1Minus1 << 6) | (n2Minus1 << 4) | (n3Minus1 << 2) | (n4Minus1);
scratch[writePos++] = (byte) flag;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
writePos += n1Minus1 + 1;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
writePos += n2Minus1 + 1;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
writePos += n3Minus1 + 1;
BitUtil.VH_LE_INT.set(scratch, writePos, Math.toIntExact(values[readPos++]));
writePos += n4Minus1 + 1;
out.writeBytes(scratch, writePos);
}
// tail vints
for (; readPos < limit; readPos++) {
out.writeVInt(Math.toIntExact(values[readPos]));
}
}
}