Cache buckets to speed up BytesRefHash#sort (#12784)

This commit is contained in:
gf2121 2023-11-10 16:13:41 +08:00 committed by GitHub
parent 904a994f66
commit d4583567e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 106 additions and 35 deletions

View File

@ -274,6 +274,8 @@ Optimizations
* GITHUB#12748: Specialize arc store for continuous label in FST. (Guo Feng, Chao Zhang) * GITHUB#12748: Specialize arc store for continuous label in FST. (Guo Feng, Chao Zhang)
* GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. (Guo Feng)
Changes in runtime behavior Changes in runtime behavior
--------------------- ---------------------

View File

@ -263,11 +263,10 @@ class BufferedUpdates implements Accountable {
scratch.field = deleteFieldEntry.getKey(); scratch.field = deleteFieldEntry.getKey();
BufferedUpdates.BytesRefIntMap terms = deleteFieldEntry.getValue(); BufferedUpdates.BytesRefIntMap terms = deleteFieldEntry.getValue();
int[] indices = terms.bytesRefHash.sort(); int[] indices = terms.bytesRefHash.sort();
for (int index : indices) { for (int i = 0; i < terms.bytesRefHash.size(); i++) {
if (index != -1) { int index = indices[i];
terms.bytesRefHash.get(index, scratch.bytes); terms.bytesRefHash.get(index, scratch.bytes);
consumer.accept(scratch, terms.values[index]); consumer.accept(scratch, terms.values[index]);
}
} }
} }
} }

View File

@ -145,8 +145,63 @@ public final class BytesRefHash implements Accountable {
*/ */
public int[] sort() { public int[] sort() {
final int[] compact = compact(); final int[] compact = compact();
assert count * 2 <= compact.length : "We need load factor <= 0.5f to speed up this sort";
final int tmpOffset = count;
new StringSorter(BytesRefComparator.NATURAL) { new StringSorter(BytesRefComparator.NATURAL) {
/**
 * Returns a radix sorter that remembers the bucket of every entry while the histogram is built,
 * so that the reorder step can read buckets back instead of computing them again. The bucket of
 * entry {@code i} is kept in {@code compact[tmpOffset + i]}.
 */
@Override
protected Sorter radixSorter(BytesRefComparator cmp) {
  return new MSBStringRadixSorter(cmp) {

    // Depth recorded by the last buildHistogram call; reorder asserts it is called with the same.
    private int k;

    /**
     * Fills the histogram for entries in {@code [from, to)} and records each entry's bucket in
     * the cache. The {@code prefixCommonLen} entries immediately before {@code from} are all
     * accounted to {@code prefixCommonBucket}, in the histogram as well as in the cache.
     */
    @Override
    protected void buildHistogram(
        int prefixCommonBucket,
        int prefixCommonLen,
        int from,
        int to,
        int k,
        int[] histogram) {
      this.k = k;
      histogram[prefixCommonBucket] = prefixCommonLen;
      final int cacheFrom = tmpOffset + from;
      Arrays.fill(compact, cacheFrom - prefixCommonLen, cacheFrom, prefixCommonBucket);
      for (int idx = from; idx < to; ++idx) {
        final int bucket = getBucket(idx, k);
        compact[tmpOffset + idx] = bucket;
        histogram[bucket]++;
      }
    }

    /** Uses half of the usual length threshold; the level threshold is unchanged. */
    @Override
    protected boolean shouldFallback(int from, int to, int l) {
      // The cached buckets make the reorder step cheaper, hence the lower length threshold.
      if (l >= LEVEL_THRESHOLD) {
        return true;
      }
      return to - from <= LENGTH_THRESHOLD / 2;
    }

    /** Swaps entries {@code i} and {@code j} together with their cached buckets. */
    private void swapBucketCache(int i, int j) {
      swap(i, j);
      final int cachedAtI = compact[tmpOffset + i];
      compact[tmpOffset + i] = compact[tmpOffset + j];
      compact[tmpOffset + j] = cachedAtI;
    }

    /**
     * Moves every entry into the range of its bucket, taking the bucket from the cache written
     * by {@link #buildHistogram}.
     */
    @Override
    protected void reorder(int from, int to, int[] startOffsets, int[] endOffsets, int k) {
      assert this.k == k;
      for (int bucket = 0; bucket < HISTOGRAM_SIZE; ++bucket) {
        final int limit = endOffsets[bucket];
        int src;
        // startOffsets[bucket] only moves forward once the entry sitting there belongs to
        // this bucket, so it has to be read again on every pass.
        while ((src = startOffsets[bucket]) < limit) {
          final int target = compact[tmpOffset + from + src];
          final int dst = startOffsets[target]++;
          swapBucketCache(from + src, from + dst);
        }
      }
    }
  };
}
@Override @Override
protected void swap(int i, int j) { protected void swap(int i, int j) {
int tmp = compact[i]; int tmp = compact[i];

View File

@ -32,11 +32,11 @@ public abstract class MSBRadixSorter extends Sorter {
// this is used as a protection against the fact that radix sort performs // this is used as a protection against the fact that radix sort performs
// worse when there are long common prefixes (probably because of cache // worse when there are long common prefixes (probably because of cache
// locality) // locality)
private static final int LEVEL_THRESHOLD = 8; protected static final int LEVEL_THRESHOLD = 8;
// size of histograms: 256 + 1 to indicate that the string is finished // size of histograms: 256 + 1 to indicate that the string is finished
protected static final int HISTOGRAM_SIZE = 257; protected static final int HISTOGRAM_SIZE = 257;
// buckets below this size will be sorted with introsort // buckets below this size will be sorted with fallback sorter
private static final int LENGTH_THRESHOLD = 100; protected static final int LENGTH_THRESHOLD = 100;
// we store one histogram per recursion level // we store one histogram per recursion level
private final int[][] histograms = new int[LEVEL_THRESHOLD][]; private final int[][] histograms = new int[LEVEL_THRESHOLD][];
@ -130,15 +130,15 @@ public abstract class MSBRadixSorter extends Sorter {
} }
protected void sort(int from, int to, int k, int l) { protected void sort(int from, int to, int k, int l) {
if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) { if (shouldFallback(from, to, l)) {
introSort(from, to, k); getFallbackSorter(k).sort(from, to);
} else { } else {
radixSort(from, to, k, l); radixSort(from, to, k, l);
} }
} }
private void introSort(int from, int to, int k) { protected boolean shouldFallback(int from, int to, int l) {
getFallbackSorter(k).sort(from, to); return to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD;
} }
/** /**
@ -233,8 +233,6 @@ public abstract class MSBRadixSorter extends Sorter {
if (b != commonPrefix[j]) { if (b != commonPrefix[j]) {
commonPrefixLength = j; commonPrefixLength = j;
if (commonPrefixLength == 0) { // we have no common prefix if (commonPrefixLength == 0) { // we have no common prefix
histogram[commonPrefix[0] + 1] = i - from;
histogram[b + 1] = 1;
break outer; break outer;
} }
break; break;
@ -245,7 +243,7 @@ public abstract class MSBRadixSorter extends Sorter {
if (i < to) { if (i < to) {
// the loop got broken because there is no common prefix // the loop got broken because there is no common prefix
assert commonPrefixLength == 0; assert commonPrefixLength == 0;
buildHistogram(i + 1, to, k, histogram); buildHistogram(commonPrefix[0] + 1, i - from, i, to, k, histogram);
} else { } else {
assert commonPrefixLength > 0; assert commonPrefixLength > 0;
histogram[commonPrefix[0] + 1] = to - from; histogram[commonPrefix[0] + 1] = to - from;
@ -258,7 +256,9 @@ public abstract class MSBRadixSorter extends Sorter {
* Build a histogram of the k-th characters of values occurring between offsets {@code from} and * Build a histogram of the k-th characters of values occurring between offsets {@code from} and
* {@code to}, using {@link #getBucket}. * {@code to}, using {@link #getBucket}.
*/ */
private void buildHistogram(int from, int to, int k, int[] histogram) { protected void buildHistogram(
int prefixCommonBucket, int prefixCommonLen, int from, int to, int k, int[] histogram) {
histogram[prefixCommonBucket] = prefixCommonLen;
for (int i = from; i < to; ++i) { for (int i = from; i < to; ++i) {
histogram[getBucket(i, k)]++; histogram[getBucket(i, k)]++;
} }

View File

@ -78,7 +78,9 @@ abstract class StableStringSorter extends StringSorter {
@Override @Override
protected int compare(int i, int j) { protected int compare(int i, int j) {
return StableStringSorter.this.compare(i, j); get(scratch1, scratchBytes1, i);
get(scratch2, scratchBytes2, j);
return cmp.compare(scratchBytes1, scratchBytes2);
} }
@Override @Override

View File

@ -57,24 +57,35 @@ public abstract class StringSorter extends Sorter {
} }
} }
/**
 * An {@link MSBRadixSorter} over the {@link BytesRef} entries of the enclosing sorter. Bytes are
 * read through the supplied {@link BytesRefComparator}, and swaps are forwarded to the enclosing
 * instance.
 */
protected class MSBStringRadixSorter extends MSBRadixSorter {

  // Comparator used both to read single bytes and to build the fallback comparison.
  private final BytesRefComparator comparator;

  /**
   * Creates a sorter whose maximum length is {@code cmp.comparedBytesCount}.
   *
   * @param cmp comparator supplying the bytes and the fallback ordering
   */
  protected MSBStringRadixSorter(BytesRefComparator cmp) {
    super(cmp.comparedBytesCount);
    comparator = cmp;
  }

  /** Builds the fallback sorter, handing {@code k} through to the comparator's compare call. */
  @Override
  protected Sorter getFallbackSorter(int k) {
    return fallbackSorter((left, right) -> comparator.compare(left, right, k));
  }

  /** Loads entry {@code i} into the shared scratch and returns what the comparator reports at {@code k}. */
  @Override
  protected int byteAt(int i, int k) {
    get(scratch1, scratchBytes1, i);
    final int value = comparator.byteAt(scratchBytes1, k);
    return value;
  }

  /** Forwards the swap to the enclosing sorter. */
  @Override
  protected void swap(int i, int j) {
    StringSorter.this.swap(i, j);
  }
}
protected Sorter radixSorter(BytesRefComparator cmp) { protected Sorter radixSorter(BytesRefComparator cmp) {
return new MSBRadixSorter(cmp.comparedBytesCount) { return new MSBStringRadixSorter(cmp);
@Override
protected void swap(int i, int j) {
StringSorter.this.swap(i, j);
}
@Override
protected int byteAt(int i, int k) {
get(scratch1, scratchBytes1, i);
return cmp.byteAt(scratchBytes1, k);
}
@Override
protected Sorter getFallbackSorter(int k) {
return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k));
}
};
} }
protected Sorter fallbackSorter(Comparator<BytesRef> cmp) { protected Sorter fallbackSorter(Comparator<BytesRef> cmp) {
@ -86,7 +97,9 @@ public abstract class StringSorter extends Sorter {
@Override @Override
protected int compare(int i, int j) { protected int compare(int i, int j) {
return StringSorter.this.compare(i, j); get(scratch1, scratchBytes1, i);
get(scratch2, scratchBytes2, j);
return cmp.compare(scratchBytes1, scratchBytes2);
} }
@Override @Override