Speed up the sort when building forward index (#12712)

This commit is contained in:
gf2121 2023-10-25 13:36:52 +08:00 committed by GitHub
parent 676dceb081
commit 779592771a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 291 additions and 99 deletions

View File

@ -225,6 +225,8 @@ Optimizations
* GITHUB#12710: Use Arrays#mismatch for Outputs#common operations. (Guo Feng) * GITHUB#12710: Use Arrays#mismatch for Outputs#common operations. (Guo Feng)
* GITHUB#12712: Speed up sorting postings file with an offline radix sorter in BPIndexReader. (Guo Feng)
Changes in runtime behavior Changes in runtime behavior
--------------------- ---------------------

View File

@ -35,7 +35,7 @@ import org.apache.lucene.index.SortingCodecReader;
import org.apache.lucene.index.Terms; import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum; import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.ByteBuffersDataOutput;
import org.apache.lucene.store.DataInput; import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.DataOutput;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
@ -46,13 +46,11 @@ import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.store.TrackingDirectoryWrapper; import org.apache.lucene.store.TrackingDirectoryWrapper;
import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefComparator;
import org.apache.lucene.util.CloseableThreadLocal; import org.apache.lucene.util.CloseableThreadLocal;
import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.IntroSorter;
import org.apache.lucene.util.IntsRef; import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.OfflineSorter; import org.apache.lucene.util.packed.PackedInts;
import org.apache.lucene.util.OfflineSorter.BufferSize;
/** /**
* Implementation of "recursive graph bisection", also called "bipartite graph partitioning" and * Implementation of "recursive graph bisection", also called "bipartite graph partitioning" and
@ -654,9 +652,7 @@ public final class BPIndexReorderer {
for (int doc = postings.nextDoc(); for (int doc = postings.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS; doc != DocIdSetIterator.NO_MORE_DOCS;
doc = postings.nextDoc()) { doc = postings.nextDoc()) {
// reverse bytes so that byte order matches natural order postingsOut.writeLong(Integer.toUnsignedLong(termID) << 32 | Integer.toUnsignedLong(doc));
postingsOut.writeInt(Integer.reverseBytes(doc));
postingsOut.writeInt(Integer.reverseBytes(termID));
} }
} }
} }
@ -665,80 +661,28 @@ public final class BPIndexReorderer {
private ForwardIndex buildForwardIndex( private ForwardIndex buildForwardIndex(
Directory tempDir, String postingsFileName, int maxDoc, int maxTerm) throws IOException { Directory tempDir, String postingsFileName, int maxDoc, int maxTerm) throws IOException {
String sortedPostingsFile =
new OfflineSorter(
tempDir,
"forward-index",
// Implement BytesRefComparator to make OfflineSorter use radix sort
new BytesRefComparator(2 * Integer.BYTES) {
@Override
protected int byteAt(BytesRef ref, int i) {
return ref.bytes[ref.offset + i] & 0xFF;
}
@Override
public int compare(BytesRef o1, BytesRef o2, int k) {
assert o1.length == 2 * Integer.BYTES;
assert o2.length == 2 * Integer.BYTES;
return ArrayUtil.compareUnsigned8(o1.bytes, o1.offset, o2.bytes, o2.offset);
}
},
BufferSize.megabytes((long) (ramBudgetMB / getParallelism())),
OfflineSorter.MAX_TEMPFILES,
2 * Integer.BYTES,
forkJoinPool,
getParallelism()) {
@Override
protected ByteSequencesReader getReader(ChecksumIndexInput in, String name)
throws IOException {
return new ByteSequencesReader(in, postingsFileName) {
{
ref.grow(2 * Integer.BYTES);
ref.setLength(2 * Integer.BYTES);
}
@Override
public BytesRef next() throws IOException {
if (in.getFilePointer() >= end) {
return null;
}
// optimized read of 8 bytes
in.readBytes(ref.bytes(), 0, 2 * Integer.BYTES);
return ref.get();
}
};
}
@Override
protected ByteSequencesWriter getWriter(IndexOutput out, long itemCount)
throws IOException {
return new ByteSequencesWriter(out) {
@Override
public void write(byte[] bytes, int off, int len) throws IOException {
assert len == 2 * Integer.BYTES;
// optimized read of 8 bytes
out.writeBytes(bytes, off, len);
}
};
}
}.sort(postingsFileName);
String termIDsFileName; String termIDsFileName;
String startOffsetsFileName; String startOffsetsFileName;
int prevDoc = -1; try (IndexOutput termIDs = tempDir.createTempOutput("term-ids", "", IOContext.DEFAULT);
try (IndexInput sortedPostings = tempDir.openInput(sortedPostingsFile, IOContext.READONCE);
IndexOutput termIDs = tempDir.createTempOutput("term-ids", "", IOContext.DEFAULT);
IndexOutput startOffsets = IndexOutput startOffsets =
tempDir.createTempOutput("start-offsets", "", IOContext.DEFAULT)) { tempDir.createTempOutput("start-offsets", "", IOContext.DEFAULT)) {
termIDsFileName = termIDs.getName(); termIDsFileName = termIDs.getName();
startOffsetsFileName = startOffsets.getName(); startOffsetsFileName = startOffsets.getName();
final long end = sortedPostings.length() - CodecUtil.footerLength();
int[] buffer = new int[TERM_IDS_BLOCK_SIZE]; int[] buffer = new int[TERM_IDS_BLOCK_SIZE];
new ForwardIndexSorter(tempDir)
.sortAndConsume(
postingsFileName,
maxDoc,
new LongConsumer() {
int prevDoc = -1;
int bufferLen = 0; int bufferLen = 0;
while (sortedPostings.getFilePointer() < end) {
final int doc = Integer.reverseBytes(sortedPostings.readInt()); @Override
final int termID = Integer.reverseBytes(sortedPostings.readInt()); public void accept(long value) throws IOException {
int doc = (int) value;
int termID = (int) (value >>> 32);
if (doc != prevDoc) { if (doc != prevDoc) {
if (bufferLen != 0) { if (bufferLen != 0) {
writeMonotonicInts(buffer, bufferLen, termIDs); writeMonotonicInts(buffer, bufferLen, termIDs);
@ -758,6 +702,9 @@ public final class BPIndexReorderer {
} }
buffer[bufferLen++] = termID; buffer[bufferLen++] = termID;
} }
@Override
public void onFinish() throws IOException {
if (bufferLen != 0) { if (bufferLen != 0) {
writeMonotonicInts(buffer, bufferLen, termIDs); writeMonotonicInts(buffer, bufferLen, termIDs);
} }
@ -767,6 +714,8 @@ public final class BPIndexReorderer {
CodecUtil.writeFooter(termIDs); CodecUtil.writeFooter(termIDs);
CodecUtil.writeFooter(startOffsets); CodecUtil.writeFooter(startOffsets);
} }
});
}
IndexInput termIDsInput = tempDir.openInput(termIDsFileName, IOContext.READ); IndexInput termIDsInput = tempDir.openInput(termIDsFileName, IOContext.READ);
IndexInput startOffsets = tempDir.openInput(startOffsetsFileName, IOContext.READ); IndexInput startOffsets = tempDir.openInput(startOffsetsFileName, IOContext.READ);
@ -991,4 +940,169 @@ public final class BPIndexReorderer {
} }
return len; return len;
} }
/**
* Use a LSB Radix Sorter to sort the (docID, termID) entries. We only need to compare docIds
* because LSB Radix Sorter is stable and termIDs already sorted.
*
* <p>This sorter will require at least 16MB ({@link #BUFFER_BYTES} * {@link #HISTOGRAM_SIZE})
* RAM.
*/
static class ForwardIndexSorter {
private static final int HISTOGRAM_SIZE = 256;
private static final int BUFFER_SIZE = 8192;
private static final int BUFFER_BYTES = BUFFER_SIZE * Long.BYTES;
private final Directory directory;
private final Bucket[] buckets = new Bucket[HISTOGRAM_SIZE];
private static class Bucket {
private final ByteBuffersDataOutput fps = new ByteBuffersDataOutput();
private final long[] buffer = new long[BUFFER_SIZE];
private IndexOutput output;
private int bufferUsed;
private int blockNum;
private long lastFp;
private int finalBlockSize;
private void addEntry(long l) throws IOException {
buffer[bufferUsed++] = l;
if (bufferUsed == BUFFER_SIZE) {
flush(false);
}
}
private void flush(boolean isFinal) throws IOException {
if (isFinal) {
finalBlockSize = bufferUsed;
}
long fp = output.getFilePointer();
fps.writeVLong(encode(fp - lastFp));
lastFp = fp;
for (int i = 0; i < bufferUsed; i++) {
output.writeLong(buffer[i]);
}
lastFp = fp;
blockNum++;
bufferUsed = 0;
}
private void reset(IndexOutput resetOutput) {
output = resetOutput;
finalBlockSize = 0;
bufferUsed = 0;
blockNum = 0;
lastFp = 0;
fps.reset();
}
}
private static long encode(long fpDelta) {
assert (fpDelta & 0x07) == 0 : "fpDelta should be multiple of 8";
if (fpDelta % BUFFER_BYTES == 0) {
return ((fpDelta / BUFFER_BYTES) << 1) | 1;
} else {
return fpDelta;
}
}
private static long decode(long fpDelta) {
if ((fpDelta & 1) == 1) {
return (fpDelta >>> 1) * BUFFER_BYTES;
} else {
return fpDelta;
}
}
ForwardIndexSorter(Directory directory) {
this.directory = directory;
for (int i = 0; i < HISTOGRAM_SIZE; i++) {
buckets[i] = new Bucket();
}
}
private void consume(String fileName, LongConsumer consumer) throws IOException {
try (IndexInput in = directory.openInput(fileName, IOContext.READONCE)) {
final long end = in.length() - CodecUtil.footerLength();
while (in.getFilePointer() < end) {
consumer.accept(in.readLong());
}
}
consumer.onFinish();
}
private void consume(String fileName, long indexFP, LongConsumer consumer) throws IOException {
try (IndexInput index = directory.openInput(fileName, IOContext.READONCE);
IndexInput value = directory.openInput(fileName, IOContext.READONCE)) {
index.seek(indexFP);
for (int i = 0; i < buckets.length; i++) {
int blockNum = index.readVInt();
int finalBlockSize = index.readVInt();
long fp = decode(index.readVLong());
for (int block = 0; block < blockNum - 1; block++) {
value.seek(fp);
for (int j = 0; j < BUFFER_SIZE; j++) {
consumer.accept(value.readLong());
}
fp += decode(index.readVLong());
}
value.seek(fp);
for (int j = 0; j < finalBlockSize; j++) {
consumer.accept(value.readLong());
}
}
consumer.onFinish();
}
}
private LongConsumer consumer(int shift) {
return new LongConsumer() {
@Override
public void accept(long value) throws IOException {
int b = (int) ((value >>> shift) & 0xFF);
Bucket bucket = buckets[b];
bucket.addEntry(value);
}
@Override
public void onFinish() throws IOException {
for (Bucket bucket : buckets) {
bucket.flush(true);
}
}
};
}
void sortAndConsume(String fileName, int maxDoc, LongConsumer consumer) throws IOException {
int bitsRequired = PackedInts.bitsRequired(maxDoc);
String sourceFileName = fileName;
long indexFP = -1;
for (int shift = 0; shift < bitsRequired; shift += 8) {
try (IndexOutput output = directory.createTempOutput(fileName, "sort", IOContext.DEFAULT)) {
Arrays.stream(buckets).forEach(b -> b.reset(output));
if (shift == 0) {
consume(sourceFileName, consumer(shift));
} else {
consume(sourceFileName, indexFP, consumer(shift));
directory.deleteFile(sourceFileName);
}
indexFP = output.getFilePointer();
for (Bucket bucket : buckets) {
output.writeVInt(bucket.blockNum);
output.writeVInt(bucket.finalBlockSize);
bucket.fps.copyTo(output);
}
CodecUtil.writeFooter(output);
sourceFileName = output.getName();
}
}
consume(sourceFileName, indexFP, consumer);
}
}
interface LongConsumer {
void accept(long value) throws IOException;
default void onFinish() throws IOException {}
}
} }

View File

@ -19,9 +19,13 @@ package org.apache.lucene.misc.index;
import static org.apache.lucene.misc.index.BPIndexReorderer.fastLog2; import static org.apache.lucene.misc.index.BPIndexReorderer.fastLog2;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread; import java.util.concurrent.ForkJoinWorkerThread;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StoredField; import org.apache.lucene.document.StoredField;
@ -36,6 +40,8 @@ import org.apache.lucene.index.StoredFields;
import org.apache.lucene.store.ByteArrayDataInput; import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ByteArrayDataOutput; import org.apache.lucene.store.ByteArrayDataOutput;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.ArrayUtil;
@ -254,4 +260,74 @@ public class TestBPIndexReorderer extends LuceneTestCase {
assertArrayEquals( assertArrayEquals(
ArrayUtil.copyOfSubArray(ints, 0, len), ArrayUtil.copyOfSubArray(restored, 0, restoredLen)); ArrayUtil.copyOfSubArray(ints, 0, len), ArrayUtil.copyOfSubArray(restored, 0, restoredLen));
} }
public void testForwardIndexSorter() throws IOException {
class Entry implements Comparable<Entry> {
final int docId;
final int termId;
Entry(int docId, int termId) {
this.docId = docId;
this.termId = termId;
}
@Override
public int compareTo(Entry o) {
if (docId == o.docId) {
return Integer.compare(termId, o.termId);
} else {
return Integer.compare(docId, o.docId);
}
}
}
try (Directory directory = newDirectory()) {
for (int bits = 2; bits < 32; bits++) {
int maxDoc = (1 << bits) - 1;
int termNum = atLeast(100);
List<Entry> entryList = new ArrayList<>();
String fileName;
try (IndexOutput out =
directory.createTempOutput("testForwardIndexSorter", "sort", IOContext.DEFAULT)) {
for (int termId = 0; termId < termNum; termId++) {
int docNum = 0;
int doc = 0;
while (docNum < 100 && doc < maxDoc - 1) {
doc = random().nextInt(doc + 1, maxDoc);
assertTrue(doc >= 0);
docNum++;
entryList.add(new Entry(doc, termId));
out.writeLong((Integer.toUnsignedLong(termId) << 32) | Integer.toUnsignedLong(doc));
}
}
CodecUtil.writeFooter(out);
fileName = out.getName();
}
Collections.sort(entryList);
new BPIndexReorderer.ForwardIndexSorter(directory)
.sortAndConsume(
fileName,
maxDoc,
new BPIndexReorderer.LongConsumer() {
int total = 0;
@Override
public void accept(long value) {
int doc = (int) value;
int term = (int) (value >>> 32);
Entry entry = entryList.get(total);
assertEquals(entry.docId, doc);
assertEquals(entry.termId, term);
total++;
}
@Override
public void onFinish() {
assertEquals(entryList.size(), total);
}
});
}
}
}
} }