From d4583567e9cf09ecdeeedd8c1f7389676fd80bd4 Mon Sep 17 00:00:00 2001 From: gf2121 <52390227+gf2121@users.noreply.github.com> Date: Fri, 10 Nov 2023 16:13:41 +0800 Subject: [PATCH] Cache buckets to speed up BytesRefHash#sort (#12784) --- lucene/CHANGES.txt | 2 + .../apache/lucene/index/BufferedUpdates.java | 9 ++- .../org/apache/lucene/util/BytesRefHash.java | 55 +++++++++++++++++++ .../apache/lucene/util/MSBRadixSorter.java | 22 ++++---- .../lucene/util/StableStringSorter.java | 4 +- .../org/apache/lucene/util/StringSorter.java | 49 +++++++++++------ 6 files changed, 106 insertions(+), 35 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 613c5d85823..5851311f357 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -274,6 +274,8 @@ Optimizations * GITHUB#12748: Specialize arc store for continuous label in FST. (Guo Feng, Chao Zhang) +* GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. (Guo Feng) + Changes in runtime behavior --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/index/BufferedUpdates.java b/lucene/core/src/java/org/apache/lucene/index/BufferedUpdates.java index a33d55248c2..0f48c261397 100644 --- a/lucene/core/src/java/org/apache/lucene/index/BufferedUpdates.java +++ b/lucene/core/src/java/org/apache/lucene/index/BufferedUpdates.java @@ -263,11 +263,10 @@ class BufferedUpdates implements Accountable { scratch.field = deleteFieldEntry.getKey(); BufferedUpdates.BytesRefIntMap terms = deleteFieldEntry.getValue(); int[] indices = terms.bytesRefHash.sort(); - for (int index : indices) { - if (index != -1) { - terms.bytesRefHash.get(index, scratch.bytes); - consumer.accept(scratch, terms.values[index]); - } + for (int i = 0; i < terms.bytesRefHash.size(); i++) { + int index = indices[i]; + terms.bytesRefHash.get(index, scratch.bytes); + consumer.accept(scratch, terms.values[index]); } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java b/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java index 1359b13c737..6e1b8fc385e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java +++ b/lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java @@ -145,8 +145,63 @@ public final class BytesRefHash implements Accountable { */ public int[] sort() { final int[] compact = compact(); + assert count * 2 <= compact.length : "We need load factor <= 0.5f to speed up this sort"; + final int tmpOffset = count; new StringSorter(BytesRefComparator.NATURAL) { + @Override + protected Sorter radixSorter(BytesRefComparator cmp) { + return new MSBStringRadixSorter(cmp) { + + private int k; + + @Override + protected void buildHistogram( + int prefixCommonBucket, + int prefixCommonLen, + int from, + int to, + int k, + int[] histogram) { + this.k = k; + histogram[prefixCommonBucket] = prefixCommonLen; + Arrays.fill( + compact, tmpOffset + from - prefixCommonLen, tmpOffset + from, prefixCommonBucket); + for (int i = from; i < to; ++i) { + int b = getBucket(i, k); + compact[tmpOffset + i] = b; + histogram[b]++; + } + } + + @Override + protected boolean shouldFallback(int from, int to, int l) { + // We lower the fallback threshold because the bucket cache speeds up the reorder + return to - from <= LENGTH_THRESHOLD / 2 || l >= LEVEL_THRESHOLD; + } + + private void swapBucketCache(int i, int j) { + swap(i, j); + int tmp = compact[tmpOffset + i]; + compact[tmpOffset + i] = compact[tmpOffset + j]; + compact[tmpOffset + j] = tmp; + } + + @Override + protected void reorder(int from, int to, int[] startOffsets, int[] endOffsets, int k) { + assert this.k == k; + for (int i = 0; i < HISTOGRAM_SIZE; ++i) { + final int limit = endOffsets[i]; + for (int h1 = startOffsets[i]; h1 < limit; h1 = startOffsets[i]) { + final int b = compact[tmpOffset + from + h1]; + final int h2 = startOffsets[b]++; + swapBucketCache(from + h1, from + h2); + } + } + } + }; + } + @Override protected void swap(int i, int j) { int tmp = compact[i]; diff --git a/lucene/core/src/java/org/apache/lucene/util/MSBRadixSorter.java b/lucene/core/src/java/org/apache/lucene/util/MSBRadixSorter.java index ebfc31077c1..f471a3f53fc 100644 --- a/lucene/core/src/java/org/apache/lucene/util/MSBRadixSorter.java +++ b/lucene/core/src/java/org/apache/lucene/util/MSBRadixSorter.java @@ -32,11 +32,11 @@ public abstract class MSBRadixSorter extends Sorter { // this is used as a protection against the fact that radix sort performs // worse when there are long common prefixes (probably because of cache // locality) - private static final int LEVEL_THRESHOLD = 8; + protected static final int LEVEL_THRESHOLD = 8; // size of histograms: 256 + 1 to indicate that the string is finished protected static final int HISTOGRAM_SIZE = 257; - // buckets below this size will be sorted with introsort - private static final int LENGTH_THRESHOLD = 100; + // buckets below this size will be sorted with fallback sorter + protected static final int LENGTH_THRESHOLD = 100; // we store one histogram per recursion level private final int[][] histograms = new int[LEVEL_THRESHOLD][]; @@ -130,15 +130,15 @@ public abstract class MSBRadixSorter extends Sorter { } protected void sort(int from, int to, int k, int l) { - if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) { - introSort(from, to, k); + if (shouldFallback(from, to, l)) { + getFallbackSorter(k).sort(from, to); } else { radixSort(from, to, k, l); } } - private void introSort(int from, int to, int k) { - getFallbackSorter(k).sort(from, to); + protected boolean shouldFallback(int from, int to, int l) { + return to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD; } /** @@ -233,8 +233,6 @@ public abstract class MSBRadixSorter extends Sorter { if (b != commonPrefix[j]) { commonPrefixLength = j; if (commonPrefixLength == 0) { // we have no common prefix - histogram[commonPrefix[0] + 1] = i - from; - histogram[b + 1] = 1; break outer; } break; @@ -245,7 +243,7 @@ public abstract class MSBRadixSorter extends Sorter { if (i < to) { // the loop got broken because there is no common prefix assert commonPrefixLength == 0; - buildHistogram(i + 1, to, k, histogram); + buildHistogram(commonPrefix[0] + 1, i - from, i, to, k, histogram); } else { assert commonPrefixLength > 0; histogram[commonPrefix[0] + 1] = to - from; @@ -258,7 +256,9 @@ public abstract class MSBRadixSorter extends Sorter { * Build an histogram of the k-th characters of values occurring between offsets {@code from} and * {@code to}, using {@link #getBucket}. */ - private void buildHistogram(int from, int to, int k, int[] histogram) { + protected void buildHistogram( + int prefixCommonBucket, int prefixCommonLen, int from, int to, int k, int[] histogram) { + histogram[prefixCommonBucket] = prefixCommonLen; for (int i = from; i < to; ++i) { histogram[getBucket(i, k)]++; } diff --git a/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java b/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java index 9d08ade563f..da3683e723c 100644 --- a/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java +++ b/lucene/core/src/java/org/apache/lucene/util/StableStringSorter.java @@ -78,7 +78,9 @@ abstract class StableStringSorter extends StringSorter { @Override protected int compare(int i, int j) { - return StableStringSorter.this.compare(i, j); + get(scratch1, scratchBytes1, i); + get(scratch2, scratchBytes2, j); + return cmp.compare(scratchBytes1, scratchBytes2); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/StringSorter.java b/lucene/core/src/java/org/apache/lucene/util/StringSorter.java index 1ae67716d17..71987d77c9d 100644 --- a/lucene/core/src/java/org/apache/lucene/util/StringSorter.java +++ b/lucene/core/src/java/org/apache/lucene/util/StringSorter.java @@ -57,24 +57,35 @@ public abstract class StringSorter extends Sorter { } } + /** A radix sorter for {@link BytesRef} */ + protected class MSBStringRadixSorter extends MSBRadixSorter { + + private final BytesRefComparator cmp; + + protected MSBStringRadixSorter(BytesRefComparator cmp) { + super(cmp.comparedBytesCount); + this.cmp = cmp; + } + + @Override + protected void swap(int i, int j) { + StringSorter.this.swap(i, j); + } + + @Override + protected int byteAt(int i, int k) { + get(scratch1, scratchBytes1, i); + return cmp.byteAt(scratchBytes1, k); + } + + @Override + protected Sorter getFallbackSorter(int k) { + return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k)); + } + } + protected Sorter radixSorter(BytesRefComparator cmp) { - return new MSBRadixSorter(cmp.comparedBytesCount) { - @Override - protected void swap(int i, int j) { - StringSorter.this.swap(i, j); - } - - @Override - protected int byteAt(int i, int k) { - get(scratch1, scratchBytes1, i); - return cmp.byteAt(scratchBytes1, k); - } - - @Override - protected Sorter getFallbackSorter(int k) { - return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k)); - } - }; + return new MSBStringRadixSorter(cmp); } protected Sorter fallbackSorter(Comparator cmp) { @@ -86,7 +97,9 @@ public abstract class StringSorter extends Sorter { @Override protected int compare(int i, int j) { - return StringSorter.this.compare(i, j); + get(scratch1, scratchBytes1, i); + get(scratch2, scratchBytes2, j); + return cmp.compare(scratchBytes1, scratchBytes2); } @Override