LUCENE-10196: Improve IntroSorter with 3-ways partitioning.

This commit is contained in:
Bruno Roustant 2021-10-21 16:18:32 +02:00
parent 0544819b78
commit 63b9e603e6
No known key found for this signature in database
GPG Key ID: 5005617B545C7FB0
5 changed files with 265 additions and 62 deletions

View File

@ -474,6 +474,8 @@ Optimizations
postings in memory, and reduced a bit of RAM overhead in
IndexWriter's internal postings book-keeping (mashudong)
* LUCENE-10196: Improve IntroSorter with 3-ways partitioning. (Bruno Roustant)
Bug Fixes
---------------------

View File

@ -20,66 +20,124 @@ package org.apache.lucene.util;
* {@link Sorter} implementation based on a variant of the quicksort algorithm called <a
* href="http://en.wikipedia.org/wiki/Introsort">introsort</a>: when the recursion level exceeds the
* log of the length of the array to sort, it falls back to heapsort. This prevents quicksort from
* running into its worst-case quadratic runtime. Small arrays are sorted with insertion sort.
* running into its worst-case quadratic runtime. Small ranges are sorted with insertion sort.
*
* <p>This sort algorithm is fast on most data shapes, especially with low cardinality. If the data
* to sort is known to be strictly ascending or descending, prefer {@link TimSorter}.
*
* @lucene.internal
*/
public abstract class IntroSorter extends Sorter {
/** Below this size threshold, the partition selection is simplified to a single median. */
private static final int SINGLE_MEDIAN_THRESHOLD = 40;
/** Create a new {@link IntroSorter}. */
public IntroSorter() {}
@Override
public final void sort(int from, int to) {
checkRange(from, to);
quicksort(from, to, 2 * MathUtil.log(to - from, 2));
sort(from, to, 2 * MathUtil.log(to - from, 2));
}
void quicksort(int from, int to, int maxDepth) {
if (to - from < BINARY_SORT_THRESHOLD) {
binarySort(from, to);
return;
} else if (--maxDepth < 0) {
heapSort(from, to);
return;
}
/**
* Sorts between from (inclusive) and to (exclusive) with intro sort.
*
* <p>Sorts small ranges with insertion sort. Fallbacks to heap sort to avoid quadratic worst
* case. Selects the pivot with medians and partitions with the Bentley-McIlroy fast 3-ways
* algorithm (Engineering a Sort Function, Bentley-McIlroy).
*/
void sort(int from, int to, int maxDepth) {
int size;
final int mid = (from + to) >>> 1;
// Sort small ranges with insertion sort.
while ((size = to - from) > INSERTION_SORT_THRESHOLD) {
if (compare(from, mid) > 0) {
swap(from, mid);
}
if (compare(mid, to - 1) > 0) {
swap(mid, to - 1);
if (compare(from, mid) > 0) {
swap(from, mid);
}
}
int left = from + 1;
int right = to - 2;
setPivot(mid);
for (; ; ) {
while (comparePivot(right) < 0) {
--right;
if (--maxDepth < 0) {
// Max recursion depth reached: fallback to heap sort.
heapSort(from, to);
return;
}
while (left < right && comparePivot(left) >= 0) {
++left;
}
if (left < right) {
swap(left, right);
--right;
// Pivot selection based on medians.
int last = to - 1;
int mid = (from + last) >>> 1;
int pivot;
if (size <= SINGLE_MEDIAN_THRESHOLD) {
// Select the pivot with a single median around the middle element.
// Do not take the median between [from, mid, last] because it hurts performance
// if the order is descending.
int range = size >> 2;
pivot = median(mid - range, mid, mid + range);
} else {
break;
// Select the pivot with the median of medians.
int range = size >> 3;
int doubleRange = range << 1;
int medianFirst = median(from, from + range, from + doubleRange);
int medianMiddle = median(mid - range, mid, mid + range);
int medianLast = median(last - doubleRange, last - range, last);
pivot = median(medianFirst, medianMiddle, medianLast);
}
// Bentley-McIlroy 3-way partitioning.
setPivot(pivot);
swap(from, pivot);
int i = from;
int j = to;
int p = from + 1;
int q = last;
while (true) {
int leftCmp, rightCmp;
while ((leftCmp = comparePivot(++i)) > 0) {}
while ((rightCmp = comparePivot(--j)) < 0) {}
if (i >= j) {
if (i == j && rightCmp == 0) {
swap(i, p);
}
break;
}
swap(i, j);
if (rightCmp == 0) {
swap(i, p++);
}
if (leftCmp == 0) {
swap(j, q--);
}
}
i = j + 1;
for (int k = from; k < p; ) {
swap(k++, j--);
}
for (int k = last; k > q; ) {
swap(k--, i++);
}
// Recursion on the smallest partition. Replace the tail recursion by a loop.
if (j - from < last - i) {
sort(from, j + 1, maxDepth);
from = i;
} else {
sort(i, to, maxDepth);
to = j + 1;
}
}
quicksort(from, left + 1, maxDepth);
quicksort(left + 1, to, maxDepth);
insertionSort(from, to);
}
/** Returns the index of the median element among three elements at provided indices. */
private int median(int i, int j, int k) {
if (compare(i, j) < 0) {
if (compare(j, k) <= 0) {
return j;
}
return compare(i, k) < 0 ? k : i;
}
if (compare(j, k) >= 0) {
return j;
}
return compare(i, k) < 0 ? i : k;
}
// Don't rely on the slow default impl of setPivot/comparePivot since

View File

@ -27,6 +27,9 @@ public abstract class Sorter {
static final int BINARY_SORT_THRESHOLD = 20;
/** Below this size threshold, the sub-range is sorted using Insertion sort. */
static final int INSERTION_SORT_THRESHOLD = 16;
/** Sole constructor, used for inheritance. */
protected Sorter() {}
@ -190,7 +193,7 @@ public abstract class Sorter {
/**
* A binary sort implementation. This performs {@code O(n*log(n))} comparisons and {@code O(n^2)}
* swaps. It is typically used by more sophisticated implementations as a fall-back when the
* numbers of items to sort has become less than {@value #BINARY_SORT_THRESHOLD}.
* number of items to sort has become less than {@value #BINARY_SORT_THRESHOLD}.
*/
void binarySort(int from, int to) {
binarySort(from, to, from + 1);
@ -216,6 +219,25 @@ public abstract class Sorter {
}
}
/**
* Sorts between from (inclusive) and to (exclusive) with insertion sort. Runs in {@code O(n^2)}.
* It is typically used by more sophisticated implementations as a fall-back when the number of
* items to sort becomes less than {@value #INSERTION_SORT_THRESHOLD}.
*/
void insertionSort(int from, int to) {
for (int i = from + 1; i < to; ) {
int current = i++;
int previous;
while (compare((previous = current - 1), current) > 0) {
swap(previous, current);
if (previous == from) {
break;
}
current = previous;
}
}
}
/**
* Use heap sort to sort items between {@code from} inclusive and {@code to} exclusive. This runs
* in {@code O(n*log(n))} and is used as a fall-back by {@link IntroSorter}.

View File

@ -17,6 +17,7 @@
package org.apache.lucene.util;
import java.util.Arrays;
import java.util.Random;
public abstract class BaseSortTestCase extends LuceneTestCase {
@ -32,7 +33,7 @@ public abstract class BaseSortTestCase extends LuceneTestCase {
@Override
public int compareTo(Entry other) {
return value < other.value ? -1 : value == other.value ? 0 : 1;
return Integer.compare(value, other.value);
}
}
@ -68,70 +69,78 @@ public abstract class BaseSortTestCase extends LuceneTestCase {
enum Strategy {
RANDOM {
@Override
public void set(Entry[] arr, int i) {
arr[i] = new Entry(random().nextInt(), i);
public void set(Entry[] arr, int i, Random random) {
arr[i] = new Entry(random.nextInt(), i);
}
},
RANDOM_LOW_CARDINALITY {
@Override
public void set(Entry[] arr, int i) {
arr[i] = new Entry(random().nextInt(6), i);
public void set(Entry[] arr, int i, Random random) {
arr[i] = new Entry(random.nextInt(6), i);
}
},
RANDOM_MEDIUM_CARDINALITY {
@Override
public void set(Entry[] arr, int i, Random random) {
arr[i] = new Entry(random.nextInt(arr.length / 2), i);
}
},
ASCENDING {
@Override
public void set(Entry[] arr, int i) {
public void set(Entry[] arr, int i, Random random) {
arr[i] =
i == 0
? new Entry(random().nextInt(6), 0)
: new Entry(arr[i - 1].value + random().nextInt(6), i);
? new Entry(random.nextInt(6), 0)
: new Entry(arr[i - 1].value + random.nextInt(6), i);
}
},
DESCENDING {
@Override
public void set(Entry[] arr, int i) {
public void set(Entry[] arr, int i, Random random) {
arr[i] =
i == 0
? new Entry(random().nextInt(6), 0)
: new Entry(arr[i - 1].value - random().nextInt(6), i);
? new Entry(random.nextInt(6), 0)
: new Entry(arr[i - 1].value - random.nextInt(6), i);
}
},
STRICTLY_DESCENDING {
@Override
public void set(Entry[] arr, int i) {
public void set(Entry[] arr, int i, Random random) {
arr[i] =
i == 0
? new Entry(random().nextInt(6), 0)
: new Entry(arr[i - 1].value - TestUtil.nextInt(random(), 1, 5), i);
? new Entry(random.nextInt(6), 0)
: new Entry(arr[i - 1].value - TestUtil.nextInt(random, 1, 5), i);
}
},
ASCENDING_SEQUENCES {
@Override
public void set(Entry[] arr, int i) {
public void set(Entry[] arr, int i, Random random) {
arr[i] =
i == 0
? new Entry(random().nextInt(6), 0)
? new Entry(random.nextInt(6), 0)
: new Entry(
rarely() ? random().nextInt(1000) : arr[i - 1].value + random().nextInt(6), i);
rarely(random) ? random.nextInt(1000) : arr[i - 1].value + random.nextInt(6),
i);
}
},
MOSTLY_ASCENDING {
@Override
public void set(Entry[] arr, int i) {
public void set(Entry[] arr, int i, Random random) {
arr[i] =
i == 0
? new Entry(random().nextInt(6), 0)
: new Entry(arr[i - 1].value + TestUtil.nextInt(random(), -8, 10), i);
? new Entry(random.nextInt(6), 0)
: new Entry(arr[i - 1].value + TestUtil.nextInt(random, -8, 10), i);
}
};
public abstract void set(Entry[] arr, int i);
public abstract void set(Entry[] arr, int i, Random random);
}
public void test(Strategy strategy, int length) {
Random random = random();
final Entry[] arr = new Entry[length];
for (int i = 0; i < arr.length; ++i) {
strategy.set(arr, i);
strategy.set(arr, i, random);
}
test(arr);
}

View File

@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.util.Locale;
import java.util.Random;
import org.apache.lucene.util.BaseSortTestCase.Entry;
import org.apache.lucene.util.BaseSortTestCase.Strategy;
/**
* Benchmark for {@link Sorter} implementations.
*
* <p>Run the static {@link #main(String[])} method to start the benchmark.
*/
public class SorterBenchmark {
private static final int ARRAY_LENGTH = 20000;
private static final int RUNS = 10;
private static final int LOOPS = 100;
private enum SorterFactory {
INTRO_SORTER(
"IntroSorter",
(arr, s) -> {
return new ArrayIntroSorter<>(arr, Entry::compareTo);
}),
TIM_SORTER(
"TimSorter",
(arr, s) -> {
return new ArrayTimSorter<>(arr, Entry::compareTo, arr.length / 64);
}),
MERGE_SORTER(
"MergeSorter",
(arr, s) -> {
return new ArrayInPlaceMergeSorter<>(arr, Entry::compareTo);
}),
;
final String name;
final Builder builder;
SorterFactory(String name, Builder builder) {
this.name = name;
this.builder = builder;
}
interface Builder {
Sorter build(Entry[] arr, Strategy strategy);
}
}
public static void main(String[] args) throws Exception {
assert false : "Disable assertions to run the benchmark";
Random random = new Random(System.currentTimeMillis());
long seed = random.nextLong();
System.out.println("WARMUP");
benchmarkSorters(Strategy.RANDOM, random, seed);
System.out.println();
for (Strategy strategy : Strategy.values()) {
System.out.println(strategy);
benchmarkSorters(strategy, random, seed);
}
}
private static void benchmarkSorters(Strategy strategy, Random random, long seed) {
for (SorterFactory sorterFactory : SorterFactory.values()) {
System.out.printf(Locale.ROOT, " %-12s...", sorterFactory.name);
random.setSeed(seed);
benchmarkSorter(strategy, sorterFactory, random);
System.out.println();
}
}
private static void benchmarkSorter(
Strategy strategy, SorterFactory sorterFactory, Random random) {
for (int i = 0; i < RUNS; i++) {
Entry[] original = createArray(strategy, random);
Entry[] clone = original.clone();
Sorter sorter = sorterFactory.builder.build(clone, strategy);
long startTimeNs = System.nanoTime();
for (int j = 0; j < LOOPS; j++) {
System.arraycopy(original, 0, clone, 0, original.length);
sorter.sort(0, clone.length);
}
long timeMs = (System.nanoTime() - startTimeNs) / 1000000;
System.out.printf(Locale.ROOT, "%5d", timeMs);
}
}
private static Entry[] createArray(Strategy strategy, Random random) {
Entry[] arr = new Entry[ARRAY_LENGTH];
for (int i = 0; i < arr.length; ++i) {
strategy.set(arr, i, random);
}
return arr;
}
}