From 779592771ade292323239dadf1d98c5e9632c793 Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Wed, 25 Oct 2023 13:36:52 +0800 Subject: [PATCH] Speed up the sort when building forward index (#12712) --- lucene/CHANGES.txt | 2 + .../lucene/misc/index/BPIndexReorderer.java | 312 ++++++++++++------ .../misc/index/TestBPIndexReorderer.java | 76 +++++ 3 files changed, 291 insertions(+), 99 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index d351c549319..a3e7cb8ed27 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -225,6 +225,8 @@ Optimizations * GITHUB#12710: Use Arrays#mismatch for Outputs#common operations. (Guo Feng) +* GITHUB#12712: Speed up sorting the postings file with an offline radix sorter in BPIndexReorderer. (Guo Feng) + Changes in runtime behavior --------------------- diff --git a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java index 3ed6be41317..dd885f40c74 100644 --- a/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java +++ b/lucene/misc/src/java/org/apache/lucene/misc/index/BPIndexReorderer.java @@ -35,7 +35,7 @@ import org.apache.lucene.index.SortingCodecReader; import org.apache.lucene.index.Terms; import org.apache.lucene.index.TermsEnum; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.ByteBuffersDataOutput; import org.apache.lucene.store.DataInput; import org.apache.lucene.store.DataOutput; import org.apache.lucene.store.Directory; @@ -46,13 +46,11 @@ import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.store.TrackingDirectoryWrapper; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.BytesRef; -import org.apache.lucene.util.BytesRefComparator; import org.apache.lucene.util.CloseableThreadLocal; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.IntsRef; -import org.apache.lucene.util.OfflineSorter; -import org.apache.lucene.util.OfflineSorter.BufferSize; +import org.apache.lucene.util.packed.PackedInts; /** * Implementation of "recursive graph bisection", also called "bipartite graph partitioning" and @@ -654,9 +652,7 @@ public final class BPIndexReorderer { for (int doc = postings.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = postings.nextDoc()) { - // reverse bytes so that byte order matches natural order - postingsOut.writeInt(Integer.reverseBytes(doc)); - postingsOut.writeInt(Integer.reverseBytes(termID)); + postingsOut.writeLong(Integer.toUnsignedLong(termID) << 32 | Integer.toUnsignedLong(doc)); } } } @@ -665,107 +661,60 @@ public final class BPIndexReorderer { private ForwardIndex buildForwardIndex( Directory tempDir, String postingsFileName, int maxDoc, int maxTerm) throws IOException { - String sortedPostingsFile = - new OfflineSorter( - tempDir, - "forward-index", - // Implement BytesRefComparator to make OfflineSorter use radix sort - new BytesRefComparator(2 * Integer.BYTES) { - @Override - protected int byteAt(BytesRef ref, int i) { - return ref.bytes[ref.offset + i] & 0xFF; - } - - @Override - public int compare(BytesRef o1, BytesRef o2, int k) { - assert o1.length == 2 * Integer.BYTES; - assert o2.length == 2 * Integer.BYTES; - return ArrayUtil.compareUnsigned8(o1.bytes, o1.offset, o2.bytes, o2.offset); - } - }, - BufferSize.megabytes((long) (ramBudgetMB / getParallelism())), -
OfflineSorter.MAX_TEMPFILES, - 2 * Integer.BYTES, - forkJoinPool, - getParallelism()) { - - @Override - protected ByteSequencesReader getReader(ChecksumIndexInput in, String name) - throws IOException { - return new ByteSequencesReader(in, postingsFileName) { - { - ref.grow(2 * Integer.BYTES); - ref.setLength(2 * Integer.BYTES); - } - - @Override - public BytesRef next() throws IOException { - if (in.getFilePointer() >= end) { - return null; - } - // optimized read of 8 bytes - in.readBytes(ref.bytes(), 0, 2 * Integer.BYTES); - return ref.get(); - } - }; - } - - @Override - protected ByteSequencesWriter getWriter(IndexOutput out, long itemCount) - throws IOException { - return new ByteSequencesWriter(out) { - @Override - public void write(byte[] bytes, int off, int len) throws IOException { - assert len == 2 * Integer.BYTES; - // optimized read of 8 bytes - out.writeBytes(bytes, off, len); - } - }; - } - }.sort(postingsFileName); String termIDsFileName; String startOffsetsFileName; - int prevDoc = -1; - try (IndexInput sortedPostings = tempDir.openInput(sortedPostingsFile, IOContext.READONCE); - IndexOutput termIDs = tempDir.createTempOutput("term-ids", "", IOContext.DEFAULT); + try (IndexOutput termIDs = tempDir.createTempOutput("term-ids", "", IOContext.DEFAULT); IndexOutput startOffsets = tempDir.createTempOutput("start-offsets", "", IOContext.DEFAULT)) { termIDsFileName = termIDs.getName(); startOffsetsFileName = startOffsets.getName(); - final long end = sortedPostings.length() - CodecUtil.footerLength(); int[] buffer = new int[TERM_IDS_BLOCK_SIZE]; - int bufferLen = 0; - while (sortedPostings.getFilePointer() < end) { - final int doc = Integer.reverseBytes(sortedPostings.readInt()); - final int termID = Integer.reverseBytes(sortedPostings.readInt()); - if (doc != prevDoc) { - if (bufferLen != 0) { - writeMonotonicInts(buffer, bufferLen, termIDs); - bufferLen = 0; - } + new ForwardIndexSorter(tempDir) + .sortAndConsume( + postingsFileName, + maxDoc, + new LongConsumer() { - assert doc > prevDoc; - for (int d = prevDoc + 1; d <= doc; ++d) { - startOffsets.writeLong(termIDs.getFilePointer()); - } - prevDoc = doc; - } - assert termID < maxTerm : termID + " " + maxTerm; - if (bufferLen == buffer.length) { - writeMonotonicInts(buffer, bufferLen, termIDs); - bufferLen = 0; - } - buffer[bufferLen++] = termID; - } - if (bufferLen != 0) { - writeMonotonicInts(buffer, bufferLen, termIDs); - } - for (int d = prevDoc + 1; d <= maxDoc; ++d) { - startOffsets.writeLong(termIDs.getFilePointer()); - } - CodecUtil.writeFooter(termIDs); - CodecUtil.writeFooter(startOffsets); + int prevDoc = -1; + int bufferLen = 0; + + @Override + public void accept(long value) throws IOException { + int doc = (int) value; + int termID = (int) (value >>> 32); + if (doc != prevDoc) { + if (bufferLen != 0) { + writeMonotonicInts(buffer, bufferLen, termIDs); + bufferLen = 0; + } + + assert doc > prevDoc; + for (int d = prevDoc + 1; d <= doc; ++d) { + startOffsets.writeLong(termIDs.getFilePointer()); + } + prevDoc = doc; + } + assert termID < maxTerm : termID + " " + maxTerm; + if (bufferLen == buffer.length) { + writeMonotonicInts(buffer, bufferLen, termIDs); + bufferLen = 0; + } + buffer[bufferLen++] = termID; + } + + @Override + public void onFinish() throws IOException { + if (bufferLen != 0) { + writeMonotonicInts(buffer, bufferLen, termIDs); + } + for (int d = prevDoc + 1; d <= maxDoc; ++d) { + startOffsets.writeLong(termIDs.getFilePointer()); + } + CodecUtil.writeFooter(termIDs); + 
CodecUtil.writeFooter(startOffsets); + } + }); } IndexInput termIDsInput = tempDir.openInput(termIDsFileName, IOContext.READ); @@ -991,4 +940,169 @@ public final class BPIndexReorderer { } return len; } + + /** + * Use an LSB radix sorter to sort the (docID, termID) entries. We only need to compare docIDs + * because the LSB radix sorter is stable and termIDs are already sorted. + * + *
<p>
This sorter will require at least 16MB ({@link #BUFFER_BYTES} * {@link #HISTOGRAM_SIZE}) + * RAM. + */ + static class ForwardIndexSorter { + + private static final int HISTOGRAM_SIZE = 256; + private static final int BUFFER_SIZE = 8192; + private static final int BUFFER_BYTES = BUFFER_SIZE * Long.BYTES; + private final Directory directory; + private final Bucket[] buckets = new Bucket[HISTOGRAM_SIZE]; + + private static class Bucket { + private final ByteBuffersDataOutput fps = new ByteBuffersDataOutput(); + private final long[] buffer = new long[BUFFER_SIZE]; + private IndexOutput output; + private int bufferUsed; + private int blockNum; + private long lastFp; + private int finalBlockSize; + + private void addEntry(long l) throws IOException { + buffer[bufferUsed++] = l; + if (bufferUsed == BUFFER_SIZE) { + flush(false); + } + } + + private void flush(boolean isFinal) throws IOException { + if (isFinal) { + finalBlockSize = bufferUsed; + } + long fp = output.getFilePointer(); + fps.writeVLong(encode(fp - lastFp)); + lastFp = fp; + for (int i = 0; i < bufferUsed; i++) { + output.writeLong(buffer[i]); + } + blockNum++; + bufferUsed = 0; + } + + private void reset(IndexOutput resetOutput) { + output = resetOutput; + finalBlockSize = 0; + bufferUsed = 0; + blockNum = 0; + lastFp = 0; + fps.reset(); + } + } + + private static long encode(long fpDelta) { + assert (fpDelta & 0x07) == 0 : "fpDelta should be a multiple of 8"; + if (fpDelta % BUFFER_BYTES == 0) { + return ((fpDelta / BUFFER_BYTES) << 1) | 1; + } else { + return fpDelta; + } + } + + private static long decode(long fpDelta) { + if ((fpDelta & 1) == 1) { + return (fpDelta >>> 1) * BUFFER_BYTES; + } else { + return fpDelta; + } + } + + ForwardIndexSorter(Directory directory) { + this.directory = directory; + for (int i = 0; i < HISTOGRAM_SIZE; i++) { + buckets[i] = new Bucket(); + } + } + + private void consume(String fileName, LongConsumer consumer) throws IOException { + try (IndexInput in = directory.openInput(fileName, IOContext.READONCE)) { + final long end = in.length() - CodecUtil.footerLength(); + while (in.getFilePointer() < end) { + consumer.accept(in.readLong()); + } + } + consumer.onFinish(); + } + + private void consume(String fileName, long indexFP, LongConsumer consumer) throws IOException { + try (IndexInput index = directory.openInput(fileName, IOContext.READONCE); + IndexInput value = directory.openInput(fileName, IOContext.READONCE)) { + index.seek(indexFP); + for (int i = 0; i < buckets.length; i++) { + int blockNum = index.readVInt(); + int finalBlockSize = index.readVInt(); + long fp = decode(index.readVLong()); + for (int block = 0; block < blockNum - 1; block++) { + value.seek(fp); + for (int j = 0; j < BUFFER_SIZE; j++) { + consumer.accept(value.readLong()); + } + fp += decode(index.readVLong()); + } + value.seek(fp); + for (int j = 0; j < finalBlockSize; j++) { + consumer.accept(value.readLong()); + } + } + consumer.onFinish(); + } + } + + private LongConsumer consumer(int shift) { + return new LongConsumer() { + @Override + public void accept(long value) throws IOException { + int b = (int) ((value >>> shift) & 0xFF); + Bucket bucket = buckets[b]; + bucket.addEntry(value); + } + + @Override + public void onFinish() throws IOException { + for (Bucket bucket : buckets) { + bucket.flush(true); + } + } + }; + } + + void sortAndConsume(String fileName, int maxDoc, LongConsumer consumer) throws IOException { + int bitsRequired = PackedInts.bitsRequired(maxDoc); + String sourceFileName = fileName; +
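// One pass per byte of the docID, least-significant byte first: each pass streams the previous file through the 256 buckets into a fresh temp output, then appends a per-bucket index (block count, final block size, fp deltas); indexFP records where that index starts so the next pass, and the final consume below, can replay the buckets in order. +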
long indexFP = -1; + for (int shift = 0; shift < bitsRequired; shift += 8) { + try (IndexOutput output = directory.createTempOutput(fileName, "sort", IOContext.DEFAULT)) { + Arrays.stream(buckets).forEach(b -> b.reset(output)); + if (shift == 0) { + consume(sourceFileName, consumer(shift)); + } else { + consume(sourceFileName, indexFP, consumer(shift)); + directory.deleteFile(sourceFileName); + } + indexFP = output.getFilePointer(); + for (Bucket bucket : buckets) { + output.writeVInt(bucket.blockNum); + output.writeVInt(bucket.finalBlockSize); + bucket.fps.copyTo(output); + } + CodecUtil.writeFooter(output); + sourceFileName = output.getName(); + } + } + consume(sourceFileName, indexFP, consumer); + } + } + + interface LongConsumer { + void accept(long value) throws IOException; + + default void onFinish() throws IOException {} + } } diff --git a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java index 13d6989ff74..491a96185ed 100644 --- a/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java +++ b/lucene/misc/src/test/org/apache/lucene/misc/index/TestBPIndexReorderer.java @@ -19,9 +19,13 @@ package org.apache.lucene.misc.index; import static org.apache.lucene.misc.index.BPIndexReorderer.fastLog2; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.List; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.ForkJoinWorkerThread; +import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.StoredField; @@ -36,6 +40,8 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.store.ByteArrayDataInput; import org.apache.lucene.store.ByteArrayDataOutput; import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.ArrayUtil; @@ -254,4 +260,74 @@ public class TestBPIndexReorderer extends LuceneTestCase { assertArrayEquals( ArrayUtil.copyOfSubArray(ints, 0, len), ArrayUtil.copyOfSubArray(restored, 0, restoredLen)); } + + public void testForwardIndexSorter() throws IOException { + class Entry implements Comparable<Entry> { + final int docId; + final int termId; + + Entry(int docId, int termId) { + this.docId = docId; + this.termId = termId; + } + + @Override + public int compareTo(Entry o) { + if (docId == o.docId) { + return Integer.compare(termId, o.termId); + } else { + return Integer.compare(docId, o.docId); + } + } + } + + try (Directory directory = newDirectory()) { + for (int bits = 2; bits < 32; bits++) { + int maxDoc = (1 << bits) - 1; + int termNum = atLeast(100); + List<Entry> entryList = new ArrayList<>(); + String fileName; + try (IndexOutput out = + directory.createTempOutput("testForwardIndexSorter", "sort", IOContext.DEFAULT)) { + for (int termId = 0; termId < termNum; termId++) { + int docNum = 0; + int doc = 0; + while (docNum < 100 && doc < maxDoc - 1) { + doc = random().nextInt(doc + 1, maxDoc); + assertTrue(doc >= 0); + docNum++; + entryList.add(new Entry(doc, termId)); + out.writeLong((Integer.toUnsignedLong(termId) << 32) | Integer.toUnsignedLong(doc)); + } + } + CodecUtil.writeFooter(out); + fileName = out.getName(); + } + Collections.sort(entryList); + new
BPIndexReorderer.ForwardIndexSorter(directory) + .sortAndConsume( + fileName, + maxDoc, + new BPIndexReorderer.LongConsumer() { + + int total = 0; + + @Override + public void accept(long value) { + int doc = (int) value; + int term = (int) (value >>> 32); + Entry entry = entryList.get(total); + assertEquals(entry.docId, doc); + assertEquals(entry.termId, term); + total++; + } + + @Override + public void onFinish() { + assertEquals(entryList.size(), total); + } + }); + } + } + } }
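Note on the approach: the patch packs each posting as (termID << 32 | doc) and appends postings in termID order, so a stable LSB radix sort keyed on the docID bytes alone leaves entries in (doc, termID) order; termIDs are never compared. Below is a minimal in-memory sketch of that invariant (the class and method names are hypothetical, and unlike the patch's ForwardIndexSorter it sorts on the heap rather than through temp files):

class LsbRadixSketch {

  // Pack exactly as writePostings does: termID in the high 32 bits, doc in the low 32 bits.
  static long pack(int termID, int doc) {
    return Integer.toUnsignedLong(termID) << 32 | Integer.toUnsignedLong(doc);
  }

  // Stable LSB radix sort on the low 32 bits (the docID) only, one byte per pass.
  // Stability is what lets the pre-existing termID order survive untouched.
  static void sortByDoc(long[] entries, int maxDoc) {
    int bitsRequired = 32 - Integer.numberOfLeadingZeros(Math.max(1, maxDoc));
    long[] scratch = new long[entries.length];
    for (int shift = 0; shift < bitsRequired; shift += 8) {
      int[] offsets = new int[257];
      for (long e : entries) {
        offsets[(int) ((e >>> shift) & 0xFF) + 1]++; // histogram of this byte
      }
      for (int i = 1; i < offsets.length; i++) {
        offsets[i] += offsets[i - 1]; // prefix sums: bucket start positions
      }
      for (long e : entries) {
        scratch[offsets[(int) ((e >>> shift) & 0xFF)]++] = e; // stable scatter
      }
      System.arraycopy(scratch, 0, entries, 0, entries.length);
    }
  }

  public static void main(String[] args) {
    // Written per term, docs ascending within each term, terms ascending overall.
    long[] entries = {pack(0, 2), pack(0, 5), pack(1, 1), pack(1, 5), pack(2, 2)};
    sortByDoc(entries, 5);
    for (long e : entries) {
      System.out.println("doc=" + (int) e + " termID=" + (int) (e >>> 32));
    }
    // Prints (1,1), (2,0), (2,2), (5,0), (5,1) as (doc, termID): ties keep termID order.
  }
}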
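Two arithmetic notes on ForwardIndexSorter. The 16MB floor in its javadoc is exactly HISTOGRAM_SIZE * BUFFER_BYTES = 256 buckets * 8192 longs * 8 bytes. And the fp-delta encoding relies on every delta being a multiple of 8 (whole longs), which frees bit 0 to flag the common case of a completely full block; a worked round trip with BUFFER_BYTES = 65536:

  encode(65536) = (65536 / 65536) << 1 | 1 = 3    (full block; vLong-encodes in 1 byte)
  encode(65544) = 65544                           (partial block, stored as-is; bit 0 is 0)
  decode(3)     = (3 >>> 1) * 65536 = 65536
  decode(65544) = 65544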