Cache buckets to speed up BytesRefHash#sort (#12784)

This commit is contained in:
gf2121 2023-11-10 16:13:41 +08:00 committed by GitHub
parent 904a994f66
commit d4583567e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 106 additions and 35 deletions

View File

@ -274,6 +274,8 @@ Optimizations
* GITHUB#12748: Specialize arc store for continuous label in FST. (Guo Feng, Chao Zhang)
* GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. (Guo Feng)
Changes in runtime behavior
---------------------

View File

@ -263,11 +263,10 @@ class BufferedUpdates implements Accountable {
scratch.field = deleteFieldEntry.getKey();
BufferedUpdates.BytesRefIntMap terms = deleteFieldEntry.getValue();
int[] indices = terms.bytesRefHash.sort();
for (int index : indices) {
if (index != -1) {
terms.bytesRefHash.get(index, scratch.bytes);
consumer.accept(scratch, terms.values[index]);
}
for (int i = 0; i < terms.bytesRefHash.size(); i++) {
int index = indices[i];
terms.bytesRefHash.get(index, scratch.bytes);
consumer.accept(scratch, terms.values[index]);
}
}
}

View File

@ -145,8 +145,63 @@ public final class BytesRefHash implements Accountable {
*/
public int[] sort() {
final int[] compact = compact();
assert count * 2 <= compact.length : "We need load factor <= 0.5f to speed up this sort";
final int tmpOffset = count;
new StringSorter(BytesRefComparator.NATURAL) {
@Override
protected Sorter radixSorter(BytesRefComparator cmp) {
return new MSBStringRadixSorter(cmp) {
private int k;
@Override
protected void buildHistogram(
int prefixCommonBucket,
int prefixCommonLen,
int from,
int to,
int k,
int[] histogram) {
this.k = k;
histogram[prefixCommonBucket] = prefixCommonLen;
Arrays.fill(
compact, tmpOffset + from - prefixCommonLen, tmpOffset + from, prefixCommonBucket);
for (int i = from; i < to; ++i) {
int b = getBucket(i, k);
compact[tmpOffset + i] = b;
histogram[b]++;
}
}
@Override
protected boolean shouldFallback(int from, int to, int l) {
// We lower the fallback threshold because the bucket cache speeds up the reorder
return to - from <= LENGTH_THRESHOLD / 2 || l >= LEVEL_THRESHOLD;
}
private void swapBucketCache(int i, int j) {
swap(i, j);
int tmp = compact[tmpOffset + i];
compact[tmpOffset + i] = compact[tmpOffset + j];
compact[tmpOffset + j] = tmp;
}
@Override
protected void reorder(int from, int to, int[] startOffsets, int[] endOffsets, int k) {
assert this.k == k;
for (int i = 0; i < HISTOGRAM_SIZE; ++i) {
final int limit = endOffsets[i];
for (int h1 = startOffsets[i]; h1 < limit; h1 = startOffsets[i]) {
final int b = compact[tmpOffset + from + h1];
final int h2 = startOffsets[b]++;
swapBucketCache(from + h1, from + h2);
}
}
}
};
}
@Override
protected void swap(int i, int j) {
int tmp = compact[i];

View File

@ -32,11 +32,11 @@ public abstract class MSBRadixSorter extends Sorter {
// this is used as a protection against the fact that radix sort performs
// worse when there are long common prefixes (probably because of cache
// locality)
private static final int LEVEL_THRESHOLD = 8;
protected static final int LEVEL_THRESHOLD = 8;
// size of histograms: 256 + 1 to indicate that the string is finished
protected static final int HISTOGRAM_SIZE = 257;
// buckets below this size will be sorted with introsort
private static final int LENGTH_THRESHOLD = 100;
// buckets below this size will be sorted with fallback sorter
protected static final int LENGTH_THRESHOLD = 100;
// we store one histogram per recursion level
private final int[][] histograms = new int[LEVEL_THRESHOLD][];
@ -130,15 +130,15 @@ public abstract class MSBRadixSorter extends Sorter {
}
protected void sort(int from, int to, int k, int l) {
if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) {
introSort(from, to, k);
if (shouldFallback(from, to, l)) {
getFallbackSorter(k).sort(from, to);
} else {
radixSort(from, to, k, l);
}
}
private void introSort(int from, int to, int k) {
getFallbackSorter(k).sort(from, to);
protected boolean shouldFallback(int from, int to, int l) {
return to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD;
}
/**
@ -233,8 +233,6 @@ public abstract class MSBRadixSorter extends Sorter {
if (b != commonPrefix[j]) {
commonPrefixLength = j;
if (commonPrefixLength == 0) { // we have no common prefix
histogram[commonPrefix[0] + 1] = i - from;
histogram[b + 1] = 1;
break outer;
}
break;
@ -245,7 +243,7 @@ public abstract class MSBRadixSorter extends Sorter {
if (i < to) {
// the loop got broken because there is no common prefix
assert commonPrefixLength == 0;
buildHistogram(i + 1, to, k, histogram);
buildHistogram(commonPrefix[0] + 1, i - from, i, to, k, histogram);
} else {
assert commonPrefixLength > 0;
histogram[commonPrefix[0] + 1] = to - from;
@ -258,7 +256,9 @@ public abstract class MSBRadixSorter extends Sorter {
* Build an histogram of the k-th characters of values occurring between offsets {@code from} and
* {@code to}, using {@link #getBucket}.
*/
private void buildHistogram(int from, int to, int k, int[] histogram) {
protected void buildHistogram(
int prefixCommonBucket, int prefixCommonLen, int from, int to, int k, int[] histogram) {
histogram[prefixCommonBucket] = prefixCommonLen;
for (int i = from; i < to; ++i) {
histogram[getBucket(i, k)]++;
}

View File

@ -78,7 +78,9 @@ abstract class StableStringSorter extends StringSorter {
@Override
protected int compare(int i, int j) {
return StableStringSorter.this.compare(i, j);
get(scratch1, scratchBytes1, i);
get(scratch2, scratchBytes2, j);
return cmp.compare(scratchBytes1, scratchBytes2);
}
@Override

View File

@ -57,24 +57,35 @@ public abstract class StringSorter extends Sorter {
}
}
/** A radix sorter for {@link BytesRef} */
protected class MSBStringRadixSorter extends MSBRadixSorter {
private final BytesRefComparator cmp;
protected MSBStringRadixSorter(BytesRefComparator cmp) {
super(cmp.comparedBytesCount);
this.cmp = cmp;
}
@Override
protected void swap(int i, int j) {
StringSorter.this.swap(i, j);
}
@Override
protected int byteAt(int i, int k) {
get(scratch1, scratchBytes1, i);
return cmp.byteAt(scratchBytes1, k);
}
@Override
protected Sorter getFallbackSorter(int k) {
return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k));
}
}
protected Sorter radixSorter(BytesRefComparator cmp) {
return new MSBRadixSorter(cmp.comparedBytesCount) {
@Override
protected void swap(int i, int j) {
StringSorter.this.swap(i, j);
}
@Override
protected int byteAt(int i, int k) {
get(scratch1, scratchBytes1, i);
return cmp.byteAt(scratchBytes1, k);
}
@Override
protected Sorter getFallbackSorter(int k) {
return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k));
}
};
return new MSBStringRadixSorter(cmp);
}
protected Sorter fallbackSorter(Comparator<BytesRef> cmp) {
@ -86,7 +97,9 @@ public abstract class StringSorter extends Sorter {
@Override
protected int compare(int i, int j) {
return StringSorter.this.compare(i, j);
get(scratch1, scratchBytes1, i);
get(scratch2, scratchBytes2, j);
return cmp.compare(scratchBytes1, scratchBytes2);
}
@Override