Use multi-select instead of sort for Dynamic Ranges

This commit is contained in:
Houston Putman 2024-10-14 19:10:01 -05:00
parent 282945998d
commit d197f012ef
3 changed files with 491 additions and 56 deletions

View File

@ -0,0 +1,407 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.util.Arrays;
import java.util.Comparator;
import java.util.SplittableRandom;
/**
 * Adaptive weighted multi-selection algorithm derived from the introspective quick select
 * algorithm. Every slot has a value and a weight (see {@link #getValue(int)} and {@link
 * #getWeight(int)}); for each requested target weight, the selector locates the slot at which the
 * running sum of weights, taken in comparator order, first reaches that target.
 *
 * <p>Pivots are chosen with an interpolation variant of Tukey's ninther median-of-medians, and the
 * range is split with Bentley-McIlroy 3-way partitioning while the value and weight totals of the
 * outer partitions are accumulated. Only partitions that contain at least one target weight are
 * processed further. For the introspective protection, the sub-range is shuffled once when the
 * maximum recursion depth is exceeded.
 *
 * @lucene.internal
 */
public abstract class WeightedSelector {

  // Source of randomness for the introspective shuffle; created lazily on first use because most
  // selections never exceed the maximum depth.
  private SplittableRandom random;

  /** Returns the weight of the entry at slot <code>i</code>. */
  protected abstract long getWeight(int i);

  /** Returns the value of the entry at slot <code>i</code>. */
  protected abstract long getValue(int i);

  /**
   * Reorders the slots between {@code from} (inclusive) and {@code to} (exclusive) and reports,
   * for each target weight, the slot at which the running weight first reaches that target.
   *
   * <p>Note that {@code kWeights} is sorted in place; the returned array is aligned with the
   * sorted order.
   *
   * @param from first slot of the range (inclusive)
   * @param to end of the range (exclusive)
   * @param rangeTotalValue sum of the values of all slots in the range
   * @param beforeTotalValue sum of the values that precede the range
   * @param rangeWeight sum of the weights of all slots in the range
   * @param beforeWeight sum of the weights that precede the range
   * @param kWeights target weights, each within {@code [beforeWeight, beforeWeight +
   *     rangeWeight]}
   * @return one entry per target weight; entries for which no slot was recorded keep the
   *     placeholder {@code index == -1}
   * @throws IllegalArgumentException if {@code kWeights} is empty or holds a target outside the
   *     accepted interval
   */
  public final WeightRangeInfo[] select(
      int from,
      int to,
      long rangeTotalValue,
      long beforeTotalValue,
      long rangeWeight,
      long beforeWeight,
      double[] kWeights) {
    WeightRangeInfo[] kIndexResults = new WeightRangeInfo[kWeights.length];
    // WeightRangeInfo is an immutable record, so sharing one placeholder instance is safe.
    Arrays.fill(kIndexResults, new WeightRangeInfo(-1, 0, 0));
    checkArgs(rangeWeight, beforeWeight, kWeights);
    select(
        from,
        to,
        rangeTotalValue,
        beforeTotalValue,
        rangeWeight,
        beforeWeight,
        kWeights,
        0,
        kWeights.length,
        kIndexResults,
        2 * MathUtil.log(to - from, 2));
    return kIndexResults;
  }

  /**
   * Validates the target weights and sorts them in place (ascending).
   *
   * @throws IllegalArgumentException if there is no target weight, or if a target lies outside
   *     {@code [beforeWeight, beforeWeight + rangeWeight]}
   */
  void checkArgs(long rangeWeight, long beforeWeight, double[] kWeights) {
    if (kWeights.length < 1) {
      throw new IllegalArgumentException("There must be at least one k to select, none given");
    }
    // The selection below relies on the targets being in ascending order.
    Arrays.sort(kWeights);
    if (kWeights[0] < beforeWeight) {
      throw new IllegalArgumentException("All kWeights must be >= beforeWeight");
    }
    if (kWeights[kWeights.length - 1] > beforeWeight + rangeWeight) {
      // The upper bound is inclusive: a target equal to the total weight is accepted.
      throw new IllegalArgumentException("All kWeights must be <= beforeWeight + rangeWeight");
    }
  }

  /**
   * Selects the target weights {@code kWeights[kFrom, kTo)} inside the slots {@code [from, to)}.
   *
   * <p>{@code rangeTotalValue}/{@code rangeWeight} are the totals of the range itself, {@code
   * beforeTotalValue}/{@code beforeWeight} the totals of everything ordered before it. Results
   * are written into {@code kIndexResults} at the position of the corresponding target weight.
   */
  // Visible for testing.
  void select(
      int from,
      int to,
      long rangeTotalValue,
      long beforeTotalValue,
      long rangeWeight,
      long beforeWeight,
      double[] kWeights,
      int kFrom,
      int kTo,
      WeightRangeInfo[] kIndexResults,
      int maxDepth) {
    // This code is inspired from IntroSorter#sort, adapted to recurse only into the partitions
    // that contain target weights.
    // For efficiency, we must partition at least 4 entries to be able to skip
    // some boundary tests during the 3-way partitioning.
    final int size = to - from;
    if (size > 3) {

      if (--maxDepth == -1) {
        // Max recursion depth exceeded: shuffle (only once) and continue.
        shuffle(from, to);
      }

      // Pivot selection based on medians.
      int last = to - 1;
      int mid = (from + last) >>> 1;
      int pivot;
      if (size <= IntroSorter.SINGLE_MEDIAN_THRESHOLD) {
        // Select the pivot with a single median around the middle element.
        // Do not take the median between [from, mid, last] because it hurts performance
        // if the order is descending in conjunction with the 3-way partitioning.
        int range = size >> 2;
        pivot = median(mid - range, mid, mid + range);
      } else {
        // Select the pivot with a variant of the Tukey's ninther median of medians.
        // If k is close to the boundaries, select either the lowest or highest median (this variant
        // is inspired from the interpolation search).
        int range = size >> 3;
        int doubleRange = range << 1;
        int medianFirst = median(from, from + range, from + doubleRange);
        int medianMiddle = median(mid - range, mid, mid + range);
        int medianLast = median(last - doubleRange, last - range, last);

        double avgWeight = ((double) rangeWeight) / size;
        double middleWeight = kWeights[(kFrom + kTo - 1) >> 1];
        // Approximate the k we are trying to find by assuming an equal weight amongst values.
        // This is only a heuristic: any of the three medians is a valid pivot.
        int middleK = from + (int) ((middleWeight - beforeWeight) / avgWeight);
        if (middleK - from < range) {
          // k is close to 'from': select the lowest median.
          pivot = min(medianFirst, medianMiddle, medianLast);
        } else if (to - middleK <= range) {
          // k is close to 'to': select the highest median.
          pivot = max(medianFirst, medianMiddle, medianLast);
        } else {
          // Otherwise select the median of medians.
          pivot = median(medianFirst, medianMiddle, medianLast);
        }
      }

      // Bentley-McIlroy 3-way partitioning. Entries equal to the pivot are parked at both ends
      // of the range ([from, p) and (q, last]) and moved to the middle afterwards. The value and
      // weight of every entry strictly below / above the pivot are accumulated on the way.
      setPivot(pivot);
      swap(from, pivot);
      int i = from;
      int j = to;
      int p = from + 1;
      int q = last;
      long leftTotalValue = 0;
      long leftWeight = 0;
      long rightTotalValue = 0;
      long rightWeight = 0;
      while (true) {
        int leftCmp, rightCmp;
        while ((leftCmp = comparePivot(++i)) > 0) {
          leftTotalValue += getValue(i);
          leftWeight += getWeight(i);
        }
        while ((rightCmp = comparePivot(--j)) < 0) {
          rightTotalValue += getValue(j);
          rightWeight += getWeight(j);
        }
        if (i >= j) {
          if (i == j && rightCmp == 0) {
            // The scans met on an entry equal to the pivot. It must join the parked equal
            // entries (hence p++): otherwise it could stay inside the bottom partition although
            // its weight and value are not part of leftWeight / leftTotalValue.
            swap(i, p++);
          }
          break;
        }
        swap(i, j);
        // After the swap, slot i holds the entry found by the right scan (<= pivot) and slot j
        // the entry found by the left scan (>= pivot).
        if (rightCmp == 0) {
          swap(i, p++);
        } else {
          leftTotalValue += getValue(i);
          leftWeight += getWeight(i);
        }
        if (leftCmp == 0) {
          swap(j, q--);
        } else {
          rightTotalValue += getValue(j);
          rightWeight += getWeight(j);
        }
      }
      // Move the parked equal entries to the middle: afterwards [from, j] is below the pivot,
      // (j, i) is equal to the pivot and [i, to) is above the pivot.
      i = j + 1;
      for (int l = from; l < p; ) {
        swap(l++, j--);
      }
      for (int l = last; l > q; ) {
        swap(l--, i++);
      }

      long leftWeightEnd = beforeWeight + leftWeight;
      long rightWeightStart = beforeWeight + rangeWeight - rightWeight;

      // Select the K weight values contained in the bottom and top partitions.
      int topKFrom = kTo;
      int bottomKTo = kFrom;
      for (int ki = kTo - 1; ki >= kFrom; ki--) {
        // Strictly greater: a target equal to rightWeightStart is reached by the last entry of
        // the middle partition, so it must not be handed to the top partition. This also keeps
        // the bottom and top groups of targets disjoint.
        if (kWeights[ki] > rightWeightStart) {
          topKFrom = ki;
        }
        if (kWeights[ki] <= leftWeightEnd) {
          bottomKTo = ki + 1;
          break;
        }
      }

      // Recursively select the relevant k-values from the bottom group, if there are any k-values
      // to select there
      if (bottomKTo > kFrom) {
        select(
            from,
            j + 1,
            leftTotalValue,
            beforeTotalValue,
            leftWeight,
            beforeWeight,
            kWeights,
            kFrom,
            bottomKTo,
            kIndexResults,
            maxDepth);
      }

      // Recursively select the relevant k-values from the top group, if there are any k-values to
      // select there
      if (topKFrom < kTo) {
        select(
            i,
            to,
            rightTotalValue,
            beforeTotalValue + rangeTotalValue - rightTotalValue,
            rightWeight,
            beforeWeight + rangeWeight - rightWeight,
            kWeights,
            topKFrom,
            kTo,
            kIndexResults,
            maxDepth);
      }

      // Choose the k result indexes for this partition: the remaining targets fall inside the
      // run of entries equal to the pivot.
      if (bottomKTo < topKFrom) {
        findKIndexes(
            j + 1,
            i,
            beforeTotalValue + leftTotalValue,
            beforeWeight + leftWeight,
            bottomKTo,
            topKFrom,
            kWeights,
            kIndexResults);
      }
    }

    // Sort the final tiny range (3 entries or less) with a very specialized sort.
    // Larger ranges were fully handled above and match none of the cases.
    switch (size) {
      case 1:
        // Only the last target weight of the group gets a result, consistently with
        // findKIndexes when several targets end on the same slot.
        kIndexResults[kTo - 1] =
            new WeightRangeInfo(
                from, beforeTotalValue + getValue(from), beforeWeight + getWeight(from));
        break;
      case 2:
        if (compare(from, from + 1) > 0) {
          swap(from, from + 1);
        }
        findKIndexes(
            from, from + 2, beforeTotalValue, beforeWeight, kFrom, kTo, kWeights, kIndexResults);
        break;
      case 3:
        sort3(from);
        findKIndexes(
            from, from + 3, beforeTotalValue, beforeWeight, kFrom, kTo, kWeights, kIndexResults);
        break;
    }
  }

  /**
   * Walks the slots {@code [from, to)} in their current order, accumulating weights and values
   * on top of {@code beforeWeight} / {@code beforeTotalValue}, and records for the target weights
   * {@code kWeights[kFrom, kTo)} the slot at which the running weight reaches them.
   */
  private void findKIndexes(
      int from,
      int to,
      long beforeTotalValue,
      long beforeWeight,
      int kFrom,
      int kTo,
      double[] kWeights,
      WeightRangeInfo[] kIndexResults) {
    long runningWeight = beforeWeight;
    long runningTotalValue = beforeTotalValue;
    int kIdx = kFrom;
    for (int listIdx = from; listIdx < to && kIdx < kTo; listIdx++) {
      runningWeight += getWeight(listIdx);
      runningTotalValue += getValue(listIdx);
      // Skip ahead in the weight list if the same value is used for multiple weights, we will only
      // record a result index for the last weight that matches it.
      while (++kIdx < kTo && kWeights[kIdx] <= runningWeight) {}
      // Step back to the last target examined that may be satisfied by this slot.
      if (kWeights[--kIdx] <= runningWeight) {
        kIndexResults[kIdx] = new WeightRangeInfo(listIdx, runningTotalValue, runningWeight);
        // Now that we have recorded the resultIndex for this weight, go to the next weight
        kIdx++;
      }
    }
  }

  /** Returns the index of the min element among three elements at provided indices. */
  private int min(int i, int j, int k) {
    if (compare(i, j) <= 0) {
      return compare(i, k) <= 0 ? i : k;
    }
    return compare(j, k) <= 0 ? j : k;
  }

  /** Returns the index of the max element among three elements at provided indices. */
  private int max(int i, int j, int k) {
    if (compare(i, j) <= 0) {
      return compare(j, k) < 0 ? k : j;
    }
    return compare(i, k) < 0 ? k : i;
  }

  /**
   * Returns the index of the median element among three elements at provided indices. Copy of
   * {@code IntroSorter#median}.
   */
  private int median(int i, int j, int k) {
    if (compare(i, j) < 0) {
      if (compare(j, k) <= 0) {
        return j;
      }
      return compare(i, k) < 0 ? k : i;
    }
    if (compare(j, k) >= 0) {
      return j;
    }
    return compare(i, k) < 0 ? i : k;
  }

  /**
   * Sorts 3 entries starting at from (inclusive). This specialized method is more efficient than
   * calling {@link Sorter#insertionSort(int, int)}.
   */
  private void sort3(int from) {
    final int mid = from + 1;
    final int last = from + 2;
    if (compare(from, mid) <= 0) {
      if (compare(mid, last) > 0) {
        swap(mid, last);
        if (compare(from, mid) > 0) {
          swap(from, mid);
        }
      }
    } else if (compare(mid, last) >= 0) {
      swap(from, last);
    } else {
      swap(from, mid);
      if (compare(mid, last) > 0) {
        swap(mid, last);
      }
    }
  }

  /**
   * Shuffles the entries between from (inclusive) and to (exclusive) with Durstenfeld's algorithm.
   */
  private void shuffle(int from, int to) {
    if (this.random == null) {
      this.random = new SplittableRandom();
    }
    SplittableRandom random = this.random;
    for (int i = to - 1; i > from; i--) {
      swap(i, random.nextInt(from, i + 1));
    }
  }

  /** Swap values at slots <code>i</code> and <code>j</code>. */
  protected abstract void swap(int i, int j);

  /**
   * Save the value at slot <code>i</code> so that it can later be used as a pivot, see {@link
   * #comparePivot(int)}.
   */
  protected abstract void setPivot(int i);

  /**
   * Compare the pivot with the slot at <code>j</code>, similarly to {@link #compare(int, int)
   * compare(i, j)}.
   */
  protected abstract int comparePivot(int j);

  /**
   * Compare entries found in slots <code>i</code> and <code>j</code>. The contract for the returned
   * value is the same as {@link Comparator#compare(Object, Object)}.
   *
   * <p>This default implementation overwrites the saved pivot with slot <code>i</code>.
   */
  protected int compare(int i, int j) {
    setPivot(i);
    return comparePivot(j);
  }

  /**
   * Holds information for a returned weight index result
   *
   * @param index the index at which the weight range limit was found
   * @param runningValueSum the sum of values from the start of the list to the end of the range
   *     limit
   * @param runningWeight the sum of weights from the start of the list to the end of the range
   *     limit
   */
  public record WeightRangeInfo(int index, long runningValueSum, long runningWeight) {}
}

View File

@ -28,7 +28,7 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LongValues;
import org.apache.lucene.search.LongValuesSource;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.apache.lucene.util.WeightedSelector;
/**
* Methods to create dynamic ranges for numeric fields.
@ -66,6 +66,7 @@ public final class DynamicRangeUtil {
matchingDocsList.stream().mapToInt(FacetsCollector.MatchingDocs::totalHits).sum();
long[] values = new long[totalDoc];
long[] weights = new long[totalDoc];
long totalValue = 0;
long totalWeight = 0;
int overallLength = 0;
@ -107,6 +108,7 @@ public final class DynamicRangeUtil {
assert curSegmentOutput.values.length == curSegmentOutput.weights.length;
try {
totalValue = Math.addExact(curSegmentOutput.segmentTotalValue, totalValue);
totalWeight = Math.addExact(curSegmentOutput.segmentTotalWeight, totalWeight);
} catch (ArithmeticException ae) {
throw new IllegalArgumentException(
@ -118,7 +120,8 @@ public final class DynamicRangeUtil {
System.arraycopy(curSegmentOutput.weights, 0, weights, overallLength, currSegmentLen);
overallLength += currSegmentLen;
}
return computeDynamicNumericRanges(values, weights, overallLength, totalWeight, topN);
return computeDynamicNumericRanges(
values, weights, overallLength, totalValue, totalWeight, topN);
}
private static class SegmentTask implements Callable<Void> {
@ -165,6 +168,8 @@ public final class DynamicRangeUtil {
segmentOutput.values[segmentOutput.segmentIdx] = curValue;
segmentOutput.weights[segmentOutput.segmentIdx] = curWeight;
try {
segmentOutput.segmentTotalValue =
Math.addExact(segmentOutput.segmentTotalValue, curValue);
segmentOutput.segmentTotalWeight =
Math.addExact(segmentOutput.segmentTotalWeight, curWeight);
} catch (ArithmeticException ae) {
@ -180,6 +185,7 @@ public final class DynamicRangeUtil {
private static final class SegmentOutput {
private final long[] values;
private final long[] weights;
private long segmentTotalValue;
private long segmentTotalWeight;
private int segmentIdx;
@ -202,7 +208,7 @@ public final class DynamicRangeUtil {
* is used to compute the equi-weight per bin.
*/
public static List<DynamicRangeInfo> computeDynamicNumericRanges(
long[] values, long[] weights, int len, long totalWeight, int topN) {
long[] values, long[] weights, int len, long totalValue, long totalWeight, int topN) {
assert values.length == weights.length && len <= values.length && len >= 0;
assert topN >= 0;
List<DynamicRangeInfo> dynamicRangeResult = new ArrayList<>();
@ -210,16 +216,25 @@ public final class DynamicRangeUtil {
return dynamicRangeResult;
}
new InPlaceMergeSorter() {
@Override
protected int compare(int index1, int index2) {
int cmp = Long.compare(values[index1], values[index2]);
if (cmp == 0) {
// If the values are equal, sort based on the weights.
// Any weight order is correct as long as it's deterministic.
return Long.compare(weights[index1], weights[index2]);
double rangeWeightTarget = (double) totalWeight / topN;
double[] kWeights = new double[topN];
for (int i = 0; i < topN; i++) {
kWeights[i] = (i == 0 ? 0 : kWeights[i - 1]) + rangeWeightTarget;
}
return cmp;
WeightedSelector.WeightRangeInfo[] kIndexResults =
new WeightedSelector() {
private long pivotValue;
private long pivotWeight;
@Override
protected long getWeight(int i) {
return weights[i];
}
@Override
protected long getValue(int i) {
return values[i];
}
@Override
@ -231,37 +246,45 @@ public final class DynamicRangeUtil {
weights[index1] = weights[index2];
weights[index2] = tmp;
}
}.sort(0, len);
long accuWeight = 0;
long valueSum = 0;
int count = 0;
int minIdx = 0;
@Override
protected void setPivot(int i) {
pivotValue = values[i];
pivotWeight = weights[i];
}
double rangeWeightTarget = (double) totalWeight / Math.min(topN, len);
@Override
protected int comparePivot(int j) {
int cmp = Long.compare(pivotValue, values[j]);
if (cmp == 0) {
// If the values are equal, sort based on the weights.
// Any weight order is correct as long as it's deterministic.
return Long.compare(pivotWeight, weights[j]);
}
return cmp;
}
}.select(0, len, totalValue, 0, totalWeight, 0, kWeights);
for (int i = 0; i < len; i++) {
accuWeight += weights[i];
valueSum += values[i];
count++;
if (accuWeight >= rangeWeightTarget) {
int lastIdx = -1;
long lastTotalValue = 0;
long lastTotalWeight = 0;
for (int kIdx = 0; kIdx < topN; kIdx++) {
WeightedSelector.WeightRangeInfo weightRangeInfo = kIndexResults[kIdx];
if (weightRangeInfo.index() > -1) {
int count = weightRangeInfo.index() - lastIdx;
dynamicRangeResult.add(
new DynamicRangeInfo(
count, accuWeight, values[minIdx], values[i], (double) valueSum / count));
count = 0;
accuWeight = 0;
valueSum = 0;
minIdx = i + 1;
count,
(weightRangeInfo.runningWeight() - lastTotalWeight),
values[lastIdx + 1],
values[weightRangeInfo.index()],
(double) (weightRangeInfo.runningValueSum() - lastTotalValue) / count));
lastIdx = weightRangeInfo.index();
lastTotalValue = weightRangeInfo.runningValueSum();
lastTotalWeight = weightRangeInfo.runningWeight();
}
}
// capture the remaining values to create the last range
if (minIdx < len) {
dynamicRangeResult.add(
new DynamicRangeInfo(
count, accuWeight, values[minIdx], values[len - 1], (double) valueSum / count));
}
return dynamicRangeResult;
}

View File

@ -26,10 +26,12 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
long[] values = new long[1000];
long[] weights = new long[1000];
long totalValue = 0;
long totalWeight = 0;
for (int i = 0; i < 1000; i++) {
values[i] = i + 1;
weights[i] = i;
totalValue += i + 1;
totalWeight += i;
}
@ -40,7 +42,8 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
new DynamicRangeUtil.DynamicRangeInfo(159, 125133L, 709L, 867L, 788D));
expectedRangeInfoList.add(
new DynamicRangeUtil.DynamicRangeInfo(133, 124089L, 868L, 1000L, 934D));
assertDynamicNumericRangeResults(values, weights, 4, totalWeight, expectedRangeInfoList);
assertDynamicNumericRangeResults(
values, weights, 4, totalValue, totalWeight, expectedRangeInfoList);
}
public void testComputeDynamicNumericRangesWithSameValues() {
@ -55,11 +58,12 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
}
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(51, 1275L, 50L, 50L, 50D));
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(21, 1281L, 50L, 50L, 50D));
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(16, 1272L, 50L, 50L, 50D));
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(12, 1122L, 50L, 50L, 50D));
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(20, 1210L, 50L, 50L, 50D));
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(16, 1256L, 50L, 50L, 50D));
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(13, 1209L, 50L, 50L, 50D));
assertDynamicNumericRangeResults(values, weights, 4, totalWeight, expectedRangeInfoList);
assertDynamicNumericRangeResults(
values, weights, 4, 50 * values.length, totalWeight, expectedRangeInfoList);
}
public void testComputeDynamicNumericRangesWithOneValue() {
@ -68,7 +72,7 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
List<DynamicRangeUtil.DynamicRangeInfo> expectedRangeInfoList = new ArrayList<>();
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 1L, 50L, 50L, 50D));
assertDynamicNumericRangeResults(values, weights, 4, 1, expectedRangeInfoList);
assertDynamicNumericRangeResults(values, weights, 4, 50, 1, expectedRangeInfoList);
}
public void testComputeDynamicNumericRangesWithOneLargeWeight() {
@ -80,24 +84,25 @@ public class TestDynamicRangeUtil extends LuceneTestCase {
expectedRangeInfoList.add(new DynamicRangeUtil.DynamicRangeInfo(1, 52343, 14L, 14L, 14D));
expectedRangeInfoList.add(
new DynamicRangeUtil.DynamicRangeInfo(6, 2766, 32L, 455L, 163.16666666666666D));
assertDynamicNumericRangeResults(values, weights, 4, 55109, expectedRangeInfoList);
assertDynamicNumericRangeResults(values, weights, 4, 993, 55109, expectedRangeInfoList);
}
private static void assertDynamicNumericRangeResults(
long[] values,
long[] weights,
int topN,
long totalValue,
long totalWeight,
List<DynamicRangeUtil.DynamicRangeInfo> expectedDynamicRangeResult) {
List<DynamicRangeUtil.DynamicRangeInfo> mockDynamicRangeResult =
DynamicRangeUtil.computeDynamicNumericRanges(
values, weights, values.length, totalWeight, topN);
assertTrue(compareDynamicRangeResult(mockDynamicRangeResult, expectedDynamicRangeResult));
values, weights, values.length, totalValue, totalWeight, topN);
compareDynamicRangeResult(mockDynamicRangeResult, expectedDynamicRangeResult);
}
private static boolean compareDynamicRangeResult(
private static void compareDynamicRangeResult(
List<DynamicRangeUtil.DynamicRangeInfo> mockResult,
List<DynamicRangeUtil.DynamicRangeInfo> expectedResult) {
return mockResult.size() == expectedResult.size() && mockResult.containsAll(expectedResult);
assertEquals(expectedResult, mockResult);
}
}