LUCENE-10225: Improve IntroSelector with 3-way partitioning.

This commit is contained in:
Bruno Roustant 2021-11-17 10:38:27 +01:00 committed by GitHub
parent c0112dd2ff
commit c71cbac4f9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 328 additions and 165 deletions

View File

@ -51,7 +51,8 @@ Improvements
Optimizations
---------------------
(No changes)
* LUCENE-10225: Improve IntroSelector with 3-way partitioning. (Bruno Roustant, Adrien Grand)
Bug Fixes
---------------------

View File

@ -17,173 +17,200 @@
package org.apache.lucene.util;
import java.util.Comparator;
import java.util.SplittableRandom;
/**
* Implementation of the quick select algorithm.
* Adaptive selection algorithm based on the introspective quick select algorithm. The quick select
* algorithm uses an interpolation variant of Tukey's ninther median-of-medians for pivot, and
* Bentley-McIlroy 3-way partitioning. For the introspective protection, it shuffles the sub-range
* if the max recursive depth is exceeded.
*
* <p>It uses the median of the first, middle and last values as a pivot and falls back to a median
* of medians when the number of recursion levels exceeds {@code 2 lg(n)}, as a consequence it runs
* in linear time on average.
* <p>This selection algorithm is fast on most data shapes, especially on nearly sorted data, or
* when k is close to the boundaries. It runs in linear time on average.
*
* @lucene.internal
*/
public abstract class IntroSelector extends Selector {
// This selector is used repeatedly by the radix selector for sub-ranges of less than
// 100 entries. This means this selector is also optimized to be fast on small ranges.
// It uses the variant of medians-of-medians and 3-way partitioning, and finishes the
// last tiny range (3 entries or less) with a very specialized sort.
private SplittableRandom random;
@Override
public final void select(int from, int to, int k) {
checkArgs(from, to, k);
final int maxDepth = 2 * MathUtil.log(to - from, 2);
quickSelect(from, to, k, maxDepth);
select(from, to, k, 2 * MathUtil.log(to - from, 2));
}
int slowSelect(int from, int to, int k) {
return medianOfMediansSelect(from, to - 1, k);
}
// Visible for testing.
void select(int from, int to, int k, int maxDepth) {
// This code is inspired from IntroSorter#sort, adapted to loop on a single partition.
int medianOfMediansSelect(int left, int right, int k) {
do {
// Defensive check, this is also checked in the calling
// method. Including here so this method can be used
// as a self contained quickSelect implementation.
if (left == right) {
return left;
// For efficiency, we must enter the loop with at least 4 entries to be able to skip
// some boundary tests during the 3-way partitioning.
int size;
while ((size = to - from) > 3) {
if (--maxDepth == -1) {
// Max recursion depth exceeded: shuffle (only once) and continue.
shuffle(from, to);
}
int pivotIndex = pivot(left, right);
pivotIndex = partition(left, right, k, pivotIndex);
if (k == pivotIndex) {
return k;
} else if (k < pivotIndex) {
right = pivotIndex - 1;
// Pivot selection based on medians.
int last = to - 1;
int mid = (from + last) >>> 1;
int pivot;
if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) {
// Select the pivot with a single median around the middle element.
// Do not take the median between [from, mid, last] because it hurts performance
// if the order is descending in conjunction with the 3-way partitioning.
int range = size >> 2;
pivot = median(mid - range, mid, mid + range);
} else {
left = pivotIndex + 1;
}
} while (left != right);
return left;
}
private int partition(int left, int right, int k, int pivotIndex) {
setPivot(pivotIndex);
swap(pivotIndex, right);
int storeIndex = left;
for (int i = left; i < right; i++) {
if (comparePivot(i) > 0) {
swap(storeIndex, i);
storeIndex++;
}
}
int storeIndexEq = storeIndex;
for (int i = storeIndex; i < right; i++) {
if (comparePivot(i) == 0) {
swap(storeIndexEq, i);
storeIndexEq++;
}
}
swap(right, storeIndexEq);
if (k < storeIndex) {
return storeIndex;
} else if (k <= storeIndexEq) {
return k;
}
return storeIndexEq;
}
private int pivot(int left, int right) {
if (right - left < 5) {
int pivotIndex = partition5(left, right);
return pivotIndex;
}
for (int i = left; i <= right; i = i + 5) {
int subRight = i + 4;
if (subRight > right) {
subRight = right;
}
int median5 = partition5(i, subRight);
swap(median5, left + ((i - left) / 5));
}
int mid = ((right - left) / 10) + left + 1;
int to = left + ((right - left) / 5);
return medianOfMediansSelect(left, to, mid);
}
// selects the median of a group of at most five elements,
// implemented using insertion sort. Efficient due to
// bounded nature of data set.
private int partition5(int left, int right) {
int i = left + 1;
while (i <= right) {
int j = i;
while (j > left && compare(j - 1, j) > 0) {
swap(j - 1, j);
j--;
}
i++;
}
return (left + right) >>> 1;
}
private void quickSelect(int from, int to, int k, int maxDepth) {
assert from <= k;
assert k < to;
if (to - from == 1) {
return;
}
if (--maxDepth < 0) {
slowSelect(from, to, k);
return;
}
final int mid = (from + to) >>> 1;
// heuristic: we use the median of the values at from, to-1 and mid as a pivot
if (compare(from, to - 1) > 0) {
swap(from, to - 1);
}
if (compare(to - 1, mid) > 0) {
swap(to - 1, mid);
if (compare(from, to - 1) > 0) {
swap(from, to - 1);
}
}
setPivot(to - 1);
int left = from + 1;
int right = to - 2;
for (; ; ) {
while (comparePivot(left) > 0) {
++left;
// Select the pivot with a variant of the Tukey's ninther median of medians.
// If k is close to the boundaries, select either the lowest or highest median (this variant
// is inspired from the interpolation search).
int range = size >> 3;
int doubleRange = range << 1;
int medianFirst = median(from, from + range, from + doubleRange);
int medianMiddle = median(mid - range, mid, mid + range);
int medianLast = median(last - doubleRange, last - range, last);
if (k - from < range) {
// k is close to 'from': select the lowest median.
pivot = min(medianFirst, medianMiddle, medianLast);
} else if (to - k <= range) {
// k is close to 'to': select the highest median.
pivot = max(medianFirst, medianMiddle, medianLast);
} else {
// Otherwise select the median of medians.
pivot = median(medianFirst, medianMiddle, medianLast);
}
}
while (left < right && comparePivot(right) <= 0) {
--right;
// Bentley-McIlroy 3-way partitioning.
setPivot(pivot);
swap(from, pivot);
int i = from;
int j = to;
int p = from + 1;
int q = last;
while (true) {
int leftCmp, rightCmp;
while ((leftCmp = comparePivot(++i)) > 0) {}
while ((rightCmp = comparePivot(--j)) < 0) {}
if (i >= j) {
if (i == j && rightCmp == 0) {
swap(i, p);
}
break;
}
swap(i, j);
if (rightCmp == 0) {
swap(i, p++);
}
if (leftCmp == 0) {
swap(j, q--);
}
}
i = j + 1;
for (int l = from; l < p; ) {
swap(l++, j--);
}
for (int l = last; l > q; ) {
swap(l--, i++);
}
if (left < right) {
swap(left, right);
--right;
// Select the partition containing the k-th element.
if (k <= j) {
to = j + 1;
} else if (k >= i) {
from = i;
} else {
return;
}
}
// Sort the final tiny range (3 entries or less) with a very specialized sort.
switch (size) {
case 2:
if (compare(from, from + 1) > 0) {
swap(from, from + 1);
}
break;
case 3:
sort3(from);
break;
}
}
swap(left, to - 1);
}
if (left == k) {
return;
} else if (left < k) {
quickSelect(left + 1, to, k, maxDepth);
/**
 * Finds which of three slots holds the smallest entry.
 *
 * <p>On ties the earlier argument wins: {@code i} is preferred over {@code j}, and the winner of
 * those two is preferred over {@code k}. Exactly two calls to {@code compare} are made.
 *
 * @param i index of the first candidate
 * @param j index of the second candidate
 * @param k index of the third candidate
 * @return the index (one of {@code i}, {@code j}, {@code k}) of the smallest entry
 */
private int min(int i, int j, int k) {
  // First round: keep the smaller of the first two candidates.
  final int lowerOfFirstTwo = compare(i, j) <= 0 ? i : j;
  // Second round: challenge the survivor with the third candidate.
  return compare(lowerOfFirstTwo, k) <= 0 ? lowerOfFirstTwo : k;
}
/**
 * Finds which of three slots holds the largest entry.
 *
 * <p>When the entries at {@code i} and {@code j} compare as equal, {@code j} is kept; the winner
 * of those two is then kept over {@code k} unless {@code k} is strictly greater. Exactly two calls
 * to {@code compare} are made.
 *
 * @param i index of the first candidate
 * @param j index of the second candidate
 * @param k index of the third candidate
 * @return the index (one of {@code i}, {@code j}, {@code k}) of the largest entry
 */
private int max(int i, int j, int k) {
  // First round: keep the larger of the first two candidates.
  final int higherOfFirstTwo = compare(i, j) <= 0 ? j : i;
  // Second round: the third candidate only wins if it is strictly larger.
  return compare(higherOfFirstTwo, k) < 0 ? k : higherOfFirstTwo;
}
/**
 * Returns the index of the median entry among the three entries at the provided indices. This is
 * a copy of {@code IntroSorter#median}.
 *
 * <p>At most three calls to {@code compare} are made, in the order (i, j), (j, k), (i, k); the
 * last one is skipped when {@code j} is already known to be the median.
 *
 * @param i index of the first candidate
 * @param j index of the second candidate
 * @param k index of the third candidate
 * @return the index (one of {@code i}, {@code j}, {@code k}) of the median entry
 */
private int median(int i, int j, int k) {
  final boolean iBelowJ = compare(i, j) < 0;
  final int jVersusK = compare(j, k);
  // j is the median when it sits between i and k: i < j <= k, or k <= j <= i.
  final boolean jInBetween = iBelowJ ? jVersusK <= 0 : jVersusK >= 0;
  if (jInBetween) {
    return j;
  }
  // Otherwise j is an extreme and the median is i or k. k is the median exactly when i lies on
  // the same side of both j and k; in the remaining cases i sits between them.
  final boolean iBelowK = compare(i, k) < 0;
  return iBelowJ == iBelowK ? k : i;
}
/**
* Sorts 3 entries starting at from (inclusive). This specialized method is more efficient than
* calling {@link Sorter#insertionSort(int, int)}.
*/
private void sort3(int from) {
final int mid = from + 1;
final int last = from + 2;
if (compare(from, mid) <= 0) {
if (compare(mid, last) > 0) {
swap(mid, last);
if (compare(from, mid) > 0) {
swap(from, mid);
}
}
} else if (compare(mid, last) >= 0) {
swap(from, last);
} else {
quickSelect(from, left, k, maxDepth);
swap(from, mid);
if (compare(mid, last) > 0) {
swap(mid, last);
}
}
}
/**
* Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
* value is the same as {@link Comparator#compare(Object, Object)}.
* Shuffles the entries between from (inclusive) and to (exclusive) with Durstenfeld's algorithm.
*/
protected int compare(int i, int j) {
setPivot(i);
return comparePivot(j);
// Randomly permutes the entries in [from, to) using only swap(). The random generator is
// created on first use and then kept in the 'random' field for later calls.
private void shuffle(int from, int to) {
  SplittableRandom rng = this.random;
  if (rng == null) {
    rng = new SplittableRandom();
    this.random = rng;
  }
  // Walk down from the last slot; each slot is swapped with a uniformly chosen slot in
  // [from, slot] (the upper bound passed to nextInt is exclusive, hence slot + 1).
  int slot = to;
  while (--slot > from) {
    swap(slot, rng.nextInt(from, slot + 1));
  }
}
/**
@ -197,4 +224,13 @@ public abstract class IntroSelector extends Selector {
* compare(i, j)}.
*/
protected abstract int comparePivot(int j);
/**
* Compares the entries found in slots <code>i</code> and <code>j</code>. The contract for the
* returned value is the same as {@link Comparator#compare(Object, Object)}.
*
* <p>This implementation calls {@code setPivot(i)} and then returns {@code comparePivot(j)}, so
* slot <code>i</code> is the pivot once this method returns.
*
* @param i slot of the first entry; it is passed to {@code setPivot}
* @param j slot of the second entry; it is passed to {@code comparePivot}
* @return the value returned by {@code comparePivot(j)}
*/
protected int compare(int i, int j) {
// The pivot must be set before comparePivot is called: the order of these two calls matters.
setPivot(i);
return comparePivot(j);
}
}

View File

@ -20,7 +20,9 @@ package org.apache.lucene.util;
* {@link Sorter} implementation based on a variant of the quicksort algorithm called <a
* href="http://en.wikipedia.org/wiki/Introsort">introsort</a>: when the recursion level exceeds the
* log of the length of the array to sort, it falls back to heapsort. This prevents quicksort from
* running into its worst-case quadratic runtime. Small ranges are sorted with insertion sort.
* running into its worst-case quadratic runtime. Selects the pivot using Tukey's ninther
* median-of-medians, and partitions using Bentley-McIlroy 3-way partitioning. Small ranges are
* sorted with insertion sort.
*
* <p>This sort algorithm is fast on most data shapes, especially with low cardinality. If the data
* to sort is known to be strictly ascending or descending, prefer {@link TimSorter}.
@ -30,7 +32,7 @@ package org.apache.lucene.util;
public abstract class IntroSorter extends Sorter {
/** Below this size threshold, the partition selection is simplified to a single median. */
private static final int SINGLE_MEDIAN_THRESHOLD = 40;
static final int SINGLE_MEDIAN_THRESHOLD = 40;
/** Create a new {@link IntroSorter}. */
public IntroSorter() {}
@ -49,13 +51,12 @@ public abstract class IntroSorter extends Sorter {
* algorithm (Engineering a Sort Function, Bentley-McIlroy).
*/
void sort(int from, int to, int maxDepth) {
int size;
// Sort small ranges with insertion sort.
int size;
while ((size = to - from) > INSERTION_SORT_THRESHOLD) {
if (--maxDepth < 0) {
// Max recursion depth reached: fallback to heap sort.
// Max recursion depth exceeded: fallback to heap sort.
heapSort(from, to);
return;
}
@ -67,11 +68,11 @@ public abstract class IntroSorter extends Sorter {
if (size <= SINGLE_MEDIAN_THRESHOLD) {
// Select the pivot with a single median around the middle element.
// Do not take the median between [from, mid, last] because it hurts performance
// if the order is descending.
// if the order is descending in conjunction with the 3-way partitioning.
int range = size >> 2;
pivot = median(mid - range, mid, mid + range);
} else {
// Select the pivot with the median of medians.
// Select the pivot with the Tukey's ninther median of medians.
int range = size >> 3;
int doubleRange = range << 1;
int medianFirst = median(from, from + range, from + doubleRange);

View File

@ -30,7 +30,10 @@ public final class MathUtil {
* @param base must be {@code > 1}
*/
public static int log(long x, int base) {
if (base <= 1) {
if (base == 2) {
// This specialized method is 30x faster.
return x <= 0 ? 0 : 63 - Long.numberOfLeadingZeros(x);
} else if (base <= 1) {
throw new IllegalArgumentException("base must be > 1");
}
int ret = 0;

View File

@ -0,0 +1,126 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.util.Locale;
import java.util.Random;
import org.apache.lucene.util.BaseSortTestCase.Entry;
import org.apache.lucene.util.BaseSortTestCase.Strategy;
/**
 * Stand-alone timing harness for {@link Selector} implementations.
 *
 * <p>Start it through the static {@link #main(String[])} method. Assertions must be disabled,
 * otherwise the guard at the top of {@code main} fails immediately.
 */
public class SelectorBenchmark {

  /** Number of entries in every generated array. */
  private static final int ARRAY_LENGTH = 20000;

  /** Number of timed runs per selector and strategy; each run uses a newly generated array. */
  private static final int RUNS = 10;

  /** Number of select calls performed inside a single timed run. */
  private static final int LOOPS = 800;

  /** The selector implementations to measure, each paired with the label printed for it. */
  private enum SelectorFactory {
    INTRO_SELECTOR(
        "IntroSelector",
        (arr, s) ->
            new IntroSelector() {
              Entry pivot;

              @Override
              protected void swap(int i, int j) {
                ArrayUtil.swap(arr, i, j);
              }

              @Override
              protected void setPivot(int i) {
                pivot = arr[i];
              }

              @Override
              protected int comparePivot(int j) {
                return pivot.compareTo(arr[j]);
              }
            });

    final String name;
    final Builder builder;

    SelectorFactory(String name, Builder builder) {
      this.name = name;
      this.builder = builder;
    }

    /** Creates a selector operating on the given array. */
    interface Builder {
      Selector build(Entry[] arr, Strategy strategy);
    }
  }

  /**
   * Runs one warm-up pass with {@code Strategy.RANDOM}, then one measured pass per strategy, and
   * prints the timings to standard output.
   */
  public static void main(String[] args) throws Exception {
    assert false : "Disable assertions to run the benchmark";
    final Random random = new Random(System.currentTimeMillis());
    final long seed = random.nextLong();

    System.out.println("WARMUP");
    benchmarkSelectors(Strategy.RANDOM, random, seed);
    System.out.println();

    final Strategy[] strategies = Strategy.values();
    for (int s = 0; s < strategies.length; s++) {
      System.out.println(strategies[s]);
      benchmarkSelectors(strategies[s], random, seed);
    }
  }

  /**
   * Benchmarks every registered selector with the given strategy, one output line per selector.
   * The random generator is re-seeded before each selector so that all of them start from the
   * same random sequence.
   */
  private static void benchmarkSelectors(Strategy strategy, Random random, long seed) {
    final SelectorFactory[] factories = SelectorFactory.values();
    for (int f = 0; f < factories.length; f++) {
      final SelectorFactory factory = factories[f];
      System.out.printf(Locale.ROOT, " %-15s...", factory.name);
      random.setSeed(seed);
      benchmarkSelector(strategy, factory, random);
      System.out.println();
    }
  }

  /**
   * Performs {@code RUNS} timed runs for one selector and prints the elapsed milliseconds of each.
   * Within a run the working array is restored from the pristine copy before every select call,
   * and {@code k} advances by an odd step, wrapping around at the array length.
   */
  private static void benchmarkSelector(
      Strategy strategy, SelectorFactory selectorFactory, Random random) {
    int remainingRuns = RUNS;
    while (remainingRuns-- > 0) {
      final Entry[] pristine = createArray(strategy, random);
      final Entry[] scratch = pristine.clone();
      final int length = scratch.length;
      final Selector selector = selectorFactory.builder.build(scratch, strategy);

      final long startTimeNs = System.nanoTime();
      int k = random.nextInt(length);
      final int step = 2 * random.nextInt(length / 14) + 1;
      for (int loop = 0; loop < LOOPS; loop++) {
        // Restore the input so every select call sees the same unselected data.
        System.arraycopy(pristine, 0, scratch, 0, pristine.length);
        selector.select(0, length, k);
        k += step;
        if (k >= length) {
          k -= length;
        }
      }
      final long elapsedMs = (System.nanoTime() - startTimeNs) / 1_000_000;
      System.out.printf(Locale.ROOT, "%5d", elapsedMs);
    }
  }

  /** Builds an array of {@code ARRAY_LENGTH} entries, filling each slot through the strategy. */
  private static Entry[] createArray(Strategy strategy, Random random) {
    final Entry[] entries = new Entry[ARRAY_LENGTH];
    for (int slot = 0; slot < entries.length; slot++) {
      strategy.set(entries, slot, random);
    }
    return entries;
  }
}

View File

@ -17,30 +17,26 @@
package org.apache.lucene.util;
import java.util.Arrays;
import java.util.Random;
public class TestIntroSelector extends LuceneTestCase {
public void testSelect() {
Random random = random();
for (int iter = 0; iter < 100; ++iter) {
doTestSelect(false);
doTestSelect(random);
}
}
public void testSlowSelect() {
for (int iter = 0; iter < 100; ++iter) {
doTestSelect(true);
}
}
private void doTestSelect(boolean slow) {
final int from = random().nextInt(5);
final int to = from + TestUtil.nextInt(random(), 1, 10000);
final int max = random().nextBoolean() ? random().nextInt(100) : random().nextInt(100000);
Integer[] arr = new Integer[to + random().nextInt(5)];
private void doTestSelect(Random random) {
final int from = random.nextInt(5);
final int to = from + TestUtil.nextInt(random, 1, 10000);
final int max = random.nextBoolean() ? random.nextInt(100) : random.nextInt(100000);
Integer[] arr = new Integer[to + random.nextInt(5)];
for (int i = 0; i < arr.length; ++i) {
arr[i] = TestUtil.nextInt(random(), 0, max);
arr[i] = TestUtil.nextInt(random, 0, max);
}
final int k = TestUtil.nextInt(random(), from, to - 1);
final int k = TestUtil.nextInt(random, from, to - 1);
Integer[] expected = arr.clone();
Arrays.sort(expected, from, to);
@ -66,10 +62,10 @@ public class TestIntroSelector extends LuceneTestCase {
return pivot.compareTo(actual[j]);
}
};
if (slow) {
selector.slowSelect(from, to, k);
} else {
if (random.nextBoolean()) {
selector.select(from, to, k);
} else {
selector.select(from, to, k, random.nextInt(3));
}
assertEquals(expected[k], actual[k]);
@ -77,9 +73,9 @@ public class TestIntroSelector extends LuceneTestCase {
if (i < from || i >= to) {
assertSame(arr[i], actual[i]);
} else if (i <= k) {
assertTrue(actual[i].intValue() <= actual[k].intValue());
assertTrue(actual[i] <= actual[k]);
} else {
assertTrue(actual[i].intValue() >= actual[k].intValue());
assertTrue(actual[i] >= actual[k]);
}
}
}